mirror of
1
Fork 0

[performance] refactoring + add fave / follow / request / visibility caching (#1607)

* refactor visibility checking, add caching for visibility

* invalidate visibility cache items on account / status deletes

* fix requester ID passed to visibility cache nil ptr

* de-interface caches, fix home / public timeline caching + visibility

* finish adding code comments for visibility filter

* fix angry goconst linter warnings

* actually finish adding filter visibility code comments for timeline functions

* move home timeline status author check to after visibility

* remove now-unused code

* add more code comments

* add TODO code comment, update printed cache start names

* update printed cache names on stop

* start adding separate follow(request) delete db functions, add specific visibility cache tests

* add relationship type caching

* fix getting local account follows / followed-bys, other small codebase improvements

* simplify invalidation using cache hooks, add more GetAccountBy___() functions

* fix boosting to return 404 if not boostable but no error (to not leak status ID)

* remove dead code

* improved placement of cache invalidation

* update license headers

* add example follow, follow-request config entries

* add example visibility cache configuration to config file

* use specific PutFollowRequest() instead of just Put()

* add tests for all GetAccountBy()

* add GetBlockBy() tests

* update block to check primitive fields

* update and finish adding Get{Account,Block,Follow,FollowRequest}By() tests

* fix copy-pasted code

* update envparsing test

* whitespace

* fix bun struct tag

* add license header to gtscontext

* fix old license header

* improved error creation to not use fmt.Errorf() when not needed

* fix various rebase conflicts, fix account test

* remove commented-out code, fix-up mention caching

* fix mention select bun statement

* ensure mention target account populated, pass in context to customrenderer logging

* remove more uncommented code, fix typeutil test

* add statusfave database model caching

* add status fave cache configuration

* add status fave cache example config

* woops, catch missed error. nice catch linter!

* add back testrig panic on nil db

* update example configuration to match defaults, slight tweak to cache configuration defaults

* update envparsing test with new defaults

* fetch followingget to use the follow target account

* use accounnt.IsLocal() instead of empty domain check

* use constants for the cache visibility type check

* use bun.In() for notification type restriction in db query

* include replies when fetching PublicTimeline() (to account for single-author threads in Visibility{}.StatusPublicTimelineable())

* use bun query building for nested select statements to ensure working with postgres

* update public timeline future status checks to match visibility filter

* same as previous, for home timeline

* update public timeline tests to dynamically check for appropriate statuses

* migrate accounts to allow unique constraint on public_key

* provide minimal account with publicKey

---------

Signed-off-by: kim <grufwub@gmail.com>
Co-authored-by: tsmethurst <tobi.smethurst@protonmail.com>
This commit is contained in:
kim 2023-03-28 14:03:14 +01:00 committed by GitHub
parent 7d09863393
commit de6e3e5f2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
100 changed files with 4423 additions and 2367 deletions

View File

@ -230,71 +230,95 @@ db-sqlite-cache-size: "8MiB"
db-sqlite-busy-timeout: "5m"
cache:
# Cache configuration options:
#
# max-size = maximum cached objects count
# ttl = cached object lifetime
# sweep-freq = frequency to look for stale cache objects
# (zero will disable cache sweeping)
#############################
#### VISIBILITY CACHES ######
#############################
#
# Configure Status and account
# visibility cache.
visibility-max-size: 2000
visibility-ttl: "30m"
visibility-sweep-freq: "1m"
gts:
###########################
#### DATABASE CACHES ######
###########################
#
# Database cache configuration:
#
# Allows configuration of caches used
# when loading GTS models from the database.
#
# max-size = maximum cached objects count
# ttl = cached object lifetime
# sweep-freq = frequency to look for stale cache objects
# Configure GTS database
# model caches.
account-max-size: 500
account-ttl: "5m"
account-sweep-freq: "30s"
account-max-size: 2000
account-ttl: "30m"
account-sweep-freq: "1m"
block-max-size: 100
block-ttl: "5m"
block-sweep-freq: "30s"
block-ttl: "30m"
block-sweep-freq: "1m"
domain-block-max-size: 1000
domain-block-max-size: 2000
domain-block-ttl: "24h"
domain-block-sweep-freq: "1m"
emoji-max-size: 500
emoji-ttl: "5m"
emoji-sweep-freq: "30s"
emoji-max-size: 2000
emoji-ttl: "30m"
emoji-sweep-freq: "1m"
emoji-category-max-size: 100
emoji-category-ttl: "5m"
emoji-category-sweep-freq: "30s"
emoji-category-ttl: "30m"
emoji-category-sweep-freq: "1m"
media-max-size: 500
media-ttl: "5m"
media-sweep-freq: "30s"
follow-max-size: 2000
follow-ttl: "30m"
follow-sweep-freq: "1m"
mention-max-size: 500
mention-ttl: "5m"
mention-sweep-freq: "30s"
follow-request-max-size: 2000
follow-request-ttl: "30m"
follow-request-sweep-freq: "1m"
notification-max-size: 500
notification-ttl: "5m"
notification-sweep-freq: "30s"
media-max-size: 1000
media-ttl: "30m"
media-sweep-freq: "1m"
mention-max-size: 2000
mention-ttl: "30m"
mention-sweep-freq: "1m"
notification-max-size: 1000
notification-ttl: "30m"
notification-sweep-freq: "1m"
report-max-size: 100
report-ttl: "5m"
report-sweep-freq: "30s"
report-ttl: "30m"
report-sweep-freq: "1m"
status-max-size: 500
status-ttl: "5m"
status-sweep-freq: "30s"
status-max-size: 2000
status-ttl: "30m"
status-sweep-freq: "1m"
tombstone-max-size: 100
tombstone-ttl: "5m"
tombstone-sweep-freq: "30s"
status-fave-max-size: 2000
status-fave-ttl: "30m"
status-fave-sweep-freq: "1m"
user-max-size: 100
user-ttl: "5m"
user-sweep-freq: "30s"
tombstone-max-size: 500
tombstone-ttl: "30m"
tombstone-sweep-freq: "1m"
user-max-size: 500
user-ttl: "30m"
user-sweep-freq: "1m"
webfinger-max-size": 250
webfinger-ttl: "24h"
webfinger-sweep-freq": "15m"
webfinger-sweep-freq": "1m"
######################
##### WEB CONFIG #####

View File

@ -168,7 +168,7 @@ func (suite *StatusBoostTestSuite) TestPostBoostOwnFollowersOnly() {
suite.Equal("really cool gts application", responseStatus.Reblog.Application.Name)
}
// try to boost a status that's not boostable
// try to boost a status that's not boostable / visible to us
func (suite *StatusBoostTestSuite) TestPostUnboostable() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.DBTokenToToken(t)
@ -197,13 +197,13 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() {
suite.statusModule.StatusBoostPOSTHandler(ctx)
// check response
suite.Equal(http.StatusForbidden, recorder.Code) // we 403 unboostable statuses
suite.Equal(http.StatusNotFound, recorder.Code) // we 404 unboostable statuses
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
suite.NoError(err)
suite.Equal(`{"error":"Forbidden"}`, string(b))
suite.Equal(`{"error":"Not Found"}`, string(b))
}
// try to boost a status that's not visible to the user

23
internal/cache/ap.go vendored
View File

@ -17,27 +17,14 @@
package cache
type APCaches interface {
type APCaches struct{}
// Init will initialize all the ActivityPub caches in this collection.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
Init()
func (c *APCaches) Init() {}
// Start will attempt to start all of the ActivityPub caches, or panic.
Start()
func (c *APCaches) Start() {}
// Stop will attempt to stop all of the ActivityPub caches, or panic.
Stop()
}
// NewAP returns a new default implementation of APCaches.
func NewAP() APCaches {
return &apCaches{}
}
type apCaches struct{}
func (c *apCaches) Init() {}
func (c *apCaches) Start() {}
func (c *apCaches) Stop() {}
func (c *APCaches) Stop() {}

View File

@ -17,13 +17,23 @@
package cache
import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type Caches struct {
// GTS provides access to the collection of gtsmodel object caches.
// (used by the database).
GTS GTSCaches
// AP provides access to the collection of ActivityPub object caches.
// (planned to be used by the typeconverter).
AP APCaches
// Visibility provides access to the item visibility cache.
// (used by the visibility filter).
Visibility VisibilityCache
// prevent pass-by-value.
_ nocopy
}
@ -31,29 +41,77 @@ type Caches struct {
// Init will (re)initialize both the GTS and AP cache collections.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
func (c *Caches) Init() {
if c.GTS == nil {
// use default impl
c.GTS = NewGTS()
}
if c.AP == nil {
// use default impl
c.AP = NewAP()
}
// initialize caches
c.GTS.Init()
c.AP.Init()
c.Visibility.Init()
// Setup cache invalidate hooks.
// !! READ THE METHOD COMMENT
c.setuphooks()
}
// Start will start both the GTS and AP cache collections.
func (c *Caches) Start() {
c.GTS.Start()
c.AP.Start()
c.Visibility.Start()
}
// Stop will stop both the GTS and AP cache collections.
func (c *Caches) Stop() {
c.GTS.Stop()
c.AP.Stop()
c.Visibility.Stop()
}
// setuphooks sets necessary cache invalidation hooks between caches,
// as an invalidation indicates a database UPDATE / DELETE. INSERT is
// not handled by invalidation hooks and must be invalidated manually.
func (c *Caches) setuphooks() {
c.GTS.Account().SetInvalidateCallback(func(account *gtsmodel.Account) {
// Invalidate account ID cached visibility.
c.Visibility.Invalidate("ItemID", account.ID)
c.Visibility.Invalidate("RequesterID", account.ID)
})
c.GTS.Block().SetInvalidateCallback(func(block *gtsmodel.Block) {
// Invalidate block origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.AccountID)
c.Visibility.Invalidate("RequesterID", block.AccountID)
// Invalidate block target account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.TargetAccountID)
c.Visibility.Invalidate("RequesterID", block.TargetAccountID)
})
c.GTS.Follow().SetInvalidateCallback(func(follow *gtsmodel.Follow) {
// Invalidate follow origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.AccountID)
c.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow target account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.TargetAccountID)
c.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
})
c.GTS.FollowRequest().SetInvalidateCallback(func(followReq *gtsmodel.FollowRequest) {
// Invalidate follow request origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", followReq.AccountID)
c.Visibility.Invalidate("RequesterID", followReq.AccountID)
// Invalidate follow request target account ID cached visibility.
c.Visibility.Invalidate("ItemID", followReq.TargetAccountID)
c.Visibility.Invalidate("RequesterID", followReq.TargetAccountID)
})
c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) {
// Invalidate status ID cached visibility.
c.Visibility.Invalidate("ItemID", status.ID)
})
c.GTS.User().SetInvalidateCallback(func(user *gtsmodel.User) {
// Invalidate local account ID cached visibility.
c.Visibility.Invalidate("ItemID", user.AccountID)
c.Visibility.Invalidate("RequesterID", user.AccountID)
})
}

303
internal/cache/gts.go vendored
View File

@ -25,240 +25,221 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type GTSCaches interface {
// Init will initialize all the gtsmodel caches in this collection.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
Init()
// Start will attempt to start all of the gtsmodel caches, or panic.
Start()
// Stop will attempt to stop all of the gtsmodel caches, or panic.
Stop()
// Account provides access to the gtsmodel Account database cache.
Account() *result.Cache[*gtsmodel.Account]
// Block provides access to the gtsmodel Block (account) database cache.
Block() *result.Cache[*gtsmodel.Block]
// DomainBlock provides access to the domain block database cache.
DomainBlock() *domain.BlockCache
// Emoji provides access to the gtsmodel Emoji database cache.
Emoji() *result.Cache[*gtsmodel.Emoji]
// EmojiCategory provides access to the gtsmodel EmojiCategory database cache.
EmojiCategory() *result.Cache[*gtsmodel.EmojiCategory]
// Mention provides access to the gtsmodel Mention database cache.
Mention() *result.Cache[*gtsmodel.Mention]
// Media provides access to the gtsmodel Media database cache.
Media() *result.Cache[*gtsmodel.MediaAttachment]
// Notification provides access to the gtsmodel Notification database cache.
Notification() *result.Cache[*gtsmodel.Notification]
// Report provides access to the gtsmodel Report database cache.
Report() *result.Cache[*gtsmodel.Report]
// Status provides access to the gtsmodel Status database cache.
Status() *result.Cache[*gtsmodel.Status]
// Tombstone provides access to the gtsmodel Tombstone database cache.
Tombstone() *result.Cache[*gtsmodel.Tombstone]
// User provides access to the gtsmodel User database cache.
User() *result.Cache[*gtsmodel.User]
// Webfinger
Webfinger() *ttl.Cache[string, string]
}
// NewGTS returns a new default implementation of GTSCaches.
func NewGTS() GTSCaches {
return &gtsCaches{}
}
type gtsCaches struct {
type GTSCaches struct {
account *result.Cache[*gtsmodel.Account]
block *result.Cache[*gtsmodel.Block]
// TODO: maybe should be moved out of here since it's
// not actually doing anything with gtsmodel.DomainBlock.
domainBlock *domain.BlockCache
emoji *result.Cache[*gtsmodel.Emoji]
emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
follow *result.Cache[*gtsmodel.Follow]
followRequest *result.Cache[*gtsmodel.FollowRequest]
media *result.Cache[*gtsmodel.MediaAttachment]
mention *result.Cache[*gtsmodel.Mention]
notification *result.Cache[*gtsmodel.Notification]
report *result.Cache[*gtsmodel.Report]
status *result.Cache[*gtsmodel.Status]
statusFave *result.Cache[*gtsmodel.StatusFave]
tombstone *result.Cache[*gtsmodel.Tombstone]
user *result.Cache[*gtsmodel.User]
// TODO: move out of GTS caches since not using database models.
webfinger *ttl.Cache[string, string]
}
func (c *gtsCaches) Init() {
// Init will initialize all the gtsmodel caches in this collection.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
func (c *GTSCaches) Init() {
c.initAccount()
c.initBlock()
c.initDomainBlock()
c.initEmoji()
c.initEmojiCategory()
c.initFollow()
c.initFollowRequest()
c.initMedia()
c.initMention()
c.initNotification()
c.initReport()
c.initStatus()
c.initStatusFave()
c.initTombstone()
c.initUser()
c.initWebfinger()
}
func (c *gtsCaches) Start() {
tryUntil("starting gtsmodel.Account cache", 5, func() bool {
return c.account.Start(config.GetCacheGTSAccountSweepFreq())
// Start will attempt to start all of the gtsmodel caches, or panic.
func (c *GTSCaches) Start() {
tryStart(c.account, config.GetCacheGTSAccountSweepFreq())
tryStart(c.block, config.GetCacheGTSBlockSweepFreq())
tryUntil("starting domain block cache", 5, func() bool {
if sweep := config.GetCacheGTSDomainBlockSweepFreq(); sweep > 0 {
return c.domainBlock.Start(sweep)
}
return true
})
tryUntil("starting gtsmodel.Block cache", 5, func() bool {
return c.block.Start(config.GetCacheGTSBlockSweepFreq())
})
tryUntil("starting gtsmodel.DomainBlock cache", 5, func() bool {
return c.domainBlock.Start(config.GetCacheGTSDomainBlockSweepFreq())
})
tryUntil("starting gtsmodel.Emoji cache", 5, func() bool {
return c.emoji.Start(config.GetCacheGTSEmojiSweepFreq())
})
tryUntil("starting gtsmodel.EmojiCategory cache", 5, func() bool {
return c.emojiCategory.Start(config.GetCacheGTSEmojiCategorySweepFreq())
})
tryUntil("starting gtsmodel.MediaAttachment cache", 5, func() bool {
return c.media.Start(config.GetCacheGTSMediaSweepFreq())
})
tryUntil("starting gtsmodel.Mention cache", 5, func() bool {
return c.mention.Start(config.GetCacheGTSMentionSweepFreq())
})
tryUntil("starting gtsmodel.Notification cache", 5, func() bool {
return c.notification.Start(config.GetCacheGTSNotificationSweepFreq())
})
tryUntil("starting gtsmodel.Report cache", 5, func() bool {
return c.report.Start(config.GetCacheGTSReportSweepFreq())
})
tryUntil("starting gtsmodel.Status cache", 5, func() bool {
return c.status.Start(config.GetCacheGTSStatusSweepFreq())
})
tryUntil("starting gtsmodel.Tombstone cache", 5, func() bool {
return c.tombstone.Start(config.GetCacheGTSTombstoneSweepFreq())
})
tryUntil("starting gtsmodel.User cache", 5, func() bool {
return c.user.Start(config.GetCacheGTSUserSweepFreq())
})
tryUntil("starting gtsmodel.Webfinger cache", 5, func() bool {
return c.webfinger.Start(config.GetCacheGTSWebfingerSweepFreq())
tryStart(c.emoji, config.GetCacheGTSEmojiSweepFreq())
tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStart(c.follow, config.GetCacheGTSFollowSweepFreq())
tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
tryStart(c.media, config.GetCacheGTSMediaSweepFreq())
tryStart(c.mention, config.GetCacheGTSMentionSweepFreq())
tryStart(c.notification, config.GetCacheGTSNotificationSweepFreq())
tryStart(c.report, config.GetCacheGTSReportSweepFreq())
tryStart(c.status, config.GetCacheGTSStatusSweepFreq())
tryStart(c.statusFave, config.GetCacheGTSStatusFaveSweepFreq())
tryStart(c.tombstone, config.GetCacheGTSTombstoneSweepFreq())
tryStart(c.user, config.GetCacheGTSUserSweepFreq())
tryUntil("starting *gtsmodel.Webfinger cache", 5, func() bool {
if sweep := config.GetCacheGTSWebfingerSweepFreq(); sweep > 0 {
return c.webfinger.Start(sweep)
}
return true
})
}
func (c *gtsCaches) Stop() {
tryUntil("stopping gtsmodel.Account cache", 5, c.account.Stop)
tryUntil("stopping gtsmodel.Block cache", 5, c.block.Stop)
tryUntil("stopping gtsmodel.DomainBlock cache", 5, c.domainBlock.Stop)
tryUntil("stopping gtsmodel.Emoji cache", 5, c.emoji.Stop)
tryUntil("stopping gtsmodel.EmojiCategory cache", 5, c.emojiCategory.Stop)
tryUntil("stopping gtsmodel.MediaAttachment cache", 5, c.media.Stop)
tryUntil("stopping gtsmodel.Mention cache", 5, c.mention.Stop)
tryUntil("stopping gtsmodel.Notification cache", 5, c.notification.Stop)
tryUntil("stopping gtsmodel.Report cache", 5, c.report.Stop)
tryUntil("stopping gtsmodel.Status cache", 5, c.status.Stop)
tryUntil("stopping gtsmodel.Tombstone cache", 5, c.tombstone.Stop)
tryUntil("stopping gtsmodel.User cache", 5, c.user.Stop)
tryUntil("stopping gtsmodel.Webfinger cache", 5, c.webfinger.Stop)
// Stop will attempt to stop all of the gtsmodel caches, or panic.
func (c *GTSCaches) Stop() {
tryStop(c.account, config.GetCacheGTSAccountSweepFreq())
tryStop(c.block, config.GetCacheGTSBlockSweepFreq())
tryUntil("stopping domain block cache", 5, c.domainBlock.Stop)
tryStop(c.emoji, config.GetCacheGTSEmojiSweepFreq())
tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStop(c.follow, config.GetCacheGTSFollowSweepFreq())
tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
tryStop(c.media, config.GetCacheGTSMediaSweepFreq())
tryStop(c.mention, config.GetCacheGTSNotificationSweepFreq())
tryStop(c.notification, config.GetCacheGTSNotificationSweepFreq())
tryStop(c.report, config.GetCacheGTSReportSweepFreq())
tryStop(c.status, config.GetCacheGTSStatusSweepFreq())
tryStop(c.statusFave, config.GetCacheGTSStatusFaveSweepFreq())
tryStop(c.tombstone, config.GetCacheGTSTombstoneSweepFreq())
tryStop(c.user, config.GetCacheGTSUserSweepFreq())
tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.webfinger.Stop)
}
func (c *gtsCaches) Account() *result.Cache[*gtsmodel.Account] {
// Account provides access to the gtsmodel Account database cache.
func (c *GTSCaches) Account() *result.Cache[*gtsmodel.Account] {
return c.account
}
func (c *gtsCaches) Block() *result.Cache[*gtsmodel.Block] {
// Block provides access to the gtsmodel Block (account) database cache.
func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] {
return c.block
}
func (c *gtsCaches) DomainBlock() *domain.BlockCache {
// DomainBlock provides access to the domain block database cache.
func (c *GTSCaches) DomainBlock() *domain.BlockCache {
return c.domainBlock
}
func (c *gtsCaches) Emoji() *result.Cache[*gtsmodel.Emoji] {
// Emoji provides access to the gtsmodel Emoji database cache.
func (c *GTSCaches) Emoji() *result.Cache[*gtsmodel.Emoji] {
return c.emoji
}
func (c *gtsCaches) EmojiCategory() *result.Cache[*gtsmodel.EmojiCategory] {
// EmojiCategory provides access to the gtsmodel EmojiCategory database cache.
func (c *GTSCaches) EmojiCategory() *result.Cache[*gtsmodel.EmojiCategory] {
return c.emojiCategory
}
func (c *gtsCaches) Media() *result.Cache[*gtsmodel.MediaAttachment] {
// Follow provides access to the gtsmodel Follow database cache.
func (c *GTSCaches) Follow() *result.Cache[*gtsmodel.Follow] {
return c.follow
}
// FollowRequest provides access to the gtsmodel FollowRequest database cache.
func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] {
return c.followRequest
}
// Media provides access to the gtsmodel Media database cache.
func (c *GTSCaches) Media() *result.Cache[*gtsmodel.MediaAttachment] {
return c.media
}
func (c *gtsCaches) Mention() *result.Cache[*gtsmodel.Mention] {
// Mention provides access to the gtsmodel Mention database cache.
func (c *GTSCaches) Mention() *result.Cache[*gtsmodel.Mention] {
return c.mention
}
func (c *gtsCaches) Notification() *result.Cache[*gtsmodel.Notification] {
// Notification provides access to the gtsmodel Notification database cache.
func (c *GTSCaches) Notification() *result.Cache[*gtsmodel.Notification] {
return c.notification
}
func (c *gtsCaches) Report() *result.Cache[*gtsmodel.Report] {
// Report provides access to the gtsmodel Report database cache.
func (c *GTSCaches) Report() *result.Cache[*gtsmodel.Report] {
return c.report
}
func (c *gtsCaches) Status() *result.Cache[*gtsmodel.Status] {
// Status provides access to the gtsmodel Status database cache.
func (c *GTSCaches) Status() *result.Cache[*gtsmodel.Status] {
return c.status
}
func (c *gtsCaches) Tombstone() *result.Cache[*gtsmodel.Tombstone] {
// StatusFave provides access to the gtsmodel StatusFave database cache.
func (c *GTSCaches) StatusFave() *result.Cache[*gtsmodel.StatusFave] {
return c.statusFave
}
// Tombstone provides access to the gtsmodel Tombstone database cache.
func (c *GTSCaches) Tombstone() *result.Cache[*gtsmodel.Tombstone] {
return c.tombstone
}
func (c *gtsCaches) User() *result.Cache[*gtsmodel.User] {
// User provides access to the gtsmodel User database cache.
func (c *GTSCaches) User() *result.Cache[*gtsmodel.User] {
return c.user
}
func (c *gtsCaches) Webfinger() *ttl.Cache[string, string] {
// Webfinger provides access to the webfinger URL cache.
func (c *GTSCaches) Webfinger() *ttl.Cache[string, string] {
return c.webfinger
}
func (c *gtsCaches) initAccount() {
func (c *GTSCaches) initAccount() {
c.account = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
{Name: "URL"},
{Name: "Username.Domain"},
{Name: "PublicKeyURI"},
{Name: "InboxURI"},
{Name: "OutboxURI"},
{Name: "FollowersURI"},
{Name: "FollowingURI"},
}, func(a1 *gtsmodel.Account) *gtsmodel.Account {
a2 := new(gtsmodel.Account)
*a2 = *a1
return a2
}, config.GetCacheGTSAccountMaxSize())
c.account.SetTTL(config.GetCacheGTSAccountTTL(), true)
c.account.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initBlock() {
func (c *GTSCaches) initBlock() {
c.block = result.New([]result.Lookup{
{Name: "ID"},
{Name: "AccountID.TargetAccountID"},
{Name: "URI"},
{Name: "AccountID.TargetAccountID"},
}, func(b1 *gtsmodel.Block) *gtsmodel.Block {
b2 := new(gtsmodel.Block)
*b2 = *b1
return b2
}, config.GetCacheGTSBlockMaxSize())
c.block.SetTTL(config.GetCacheGTSBlockTTL(), true)
c.block.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initDomainBlock() {
func (c *GTSCaches) initDomainBlock() {
c.domainBlock = domain.New(
config.GetCacheGTSDomainBlockMaxSize(),
config.GetCacheGTSDomainBlockTTL(),
)
}
func (c *gtsCaches) initEmoji() {
func (c *GTSCaches) initEmoji() {
c.emoji = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
@ -270,9 +251,10 @@ func (c *gtsCaches) initEmoji() {
return e2
}, config.GetCacheGTSEmojiMaxSize())
c.emoji.SetTTL(config.GetCacheGTSEmojiTTL(), true)
c.emoji.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initEmojiCategory() {
func (c *GTSCaches) initEmojiCategory() {
c.emojiCategory = result.New([]result.Lookup{
{Name: "ID"},
{Name: "Name"},
@ -282,9 +264,36 @@ func (c *gtsCaches) initEmojiCategory() {
return c2
}, config.GetCacheGTSEmojiCategoryMaxSize())
c.emojiCategory.SetTTL(config.GetCacheGTSEmojiCategoryTTL(), true)
c.emojiCategory.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initMedia() {
func (c *GTSCaches) initFollow() {
c.follow = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
{Name: "AccountID.TargetAccountID"},
}, func(f1 *gtsmodel.Follow) *gtsmodel.Follow {
f2 := new(gtsmodel.Follow)
*f2 = *f1
return f2
}, config.GetCacheGTSFollowMaxSize())
c.follow.SetTTL(config.GetCacheGTSFollowTTL(), true)
}
func (c *GTSCaches) initFollowRequest() {
c.followRequest = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
{Name: "AccountID.TargetAccountID"},
}, func(f1 *gtsmodel.FollowRequest) *gtsmodel.FollowRequest {
f2 := new(gtsmodel.FollowRequest)
*f2 = *f1
return f2
}, config.GetCacheGTSFollowRequestMaxSize())
c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true)
}
func (c *GTSCaches) initMedia() {
c.media = result.New([]result.Lookup{
{Name: "ID"},
}, func(m1 *gtsmodel.MediaAttachment) *gtsmodel.MediaAttachment {
@ -293,9 +302,10 @@ func (c *gtsCaches) initMedia() {
return m2
}, config.GetCacheGTSMediaMaxSize())
c.media.SetTTL(config.GetCacheGTSMediaTTL(), true)
c.media.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initMention() {
func (c *GTSCaches) initMention() {
c.mention = result.New([]result.Lookup{
{Name: "ID"},
}, func(m1 *gtsmodel.Mention) *gtsmodel.Mention {
@ -304,9 +314,10 @@ func (c *gtsCaches) initMention() {
return m2
}, config.GetCacheGTSMentionMaxSize())
c.mention.SetTTL(config.GetCacheGTSMentionTTL(), true)
c.mention.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initNotification() {
func (c *GTSCaches) initNotification() {
c.notification = result.New([]result.Lookup{
{Name: "ID"},
}, func(n1 *gtsmodel.Notification) *gtsmodel.Notification {
@ -315,9 +326,10 @@ func (c *gtsCaches) initNotification() {
return n2
}, config.GetCacheGTSNotificationMaxSize())
c.notification.SetTTL(config.GetCacheGTSNotificationTTL(), true)
c.notification.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initReport() {
func (c *GTSCaches) initReport() {
c.report = result.New([]result.Lookup{
{Name: "ID"},
}, func(r1 *gtsmodel.Report) *gtsmodel.Report {
@ -326,9 +338,10 @@ func (c *gtsCaches) initReport() {
return r2
}, config.GetCacheGTSReportMaxSize())
c.report.SetTTL(config.GetCacheGTSReportTTL(), true)
c.report.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initStatus() {
func (c *GTSCaches) initStatus() {
c.status = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
@ -339,10 +352,24 @@ func (c *gtsCaches) initStatus() {
return s2
}, config.GetCacheGTSStatusMaxSize())
c.status.SetTTL(config.GetCacheGTSStatusTTL(), true)
c.status.IgnoreErrors(ignoreErrors)
}
func (c *GTSCaches) initStatusFave() {
c.statusFave = result.New([]result.Lookup{
{Name: "ID"},
{Name: "AccountID.StatusID"},
}, func(f1 *gtsmodel.StatusFave) *gtsmodel.StatusFave {
f2 := new(gtsmodel.StatusFave)
*f2 = *f1
return f2
}, config.GetCacheGTSStatusFaveMaxSize())
c.status.SetTTL(config.GetCacheGTSStatusFaveTTL(), true)
c.status.IgnoreErrors(ignoreErrors)
}
// initTombstone will initialize the gtsmodel.Tombstone cache.
func (c *gtsCaches) initTombstone() {
func (c *GTSCaches) initTombstone() {
c.tombstone = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
@ -352,9 +379,10 @@ func (c *gtsCaches) initTombstone() {
return t2
}, config.GetCacheGTSTombstoneMaxSize())
c.tombstone.SetTTL(config.GetCacheGTSTombstoneTTL(), true)
c.tombstone.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initUser() {
func (c *GTSCaches) initUser() {
c.user = result.New([]result.Lookup{
{Name: "ID"},
{Name: "AccountID"},
@ -367,9 +395,10 @@ func (c *gtsCaches) initUser() {
return u2
}, config.GetCacheGTSUserMaxSize())
c.user.SetTTL(config.GetCacheGTSUserTTL(), true)
c.user.IgnoreErrors(ignoreErrors)
}
func (c *gtsCaches) initWebfinger() {
func (c *GTSCaches) initWebfinger() {
c.webfinger = ttl.New[string, string](
0,
config.GetCacheGTSWebfingerMaxSize(),

View File

@ -17,7 +17,30 @@
package cache
import "github.com/superseriousbusiness/gotosocial/internal/log"
import (
"context"
"errors"
"fmt"
"time"
"codeberg.org/gruf/go-cache/v3/result"
errorsv2 "codeberg.org/gruf/go-errors/v2"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// SentinelError is returned to indicate a non-permanent error return,
// i.e. a situation in which we do not want a cache a negative result.
var SentinelError = errors.New("BUG: error should not be returned") //nolint:revive
// ignoreErrors is an error ignoring function capable of being passed to
// caches, which specifically catches and ignores our sentinel error type.
func ignoreErrors(err error) bool {
return errorsv2.Is(
SentinelError,
context.DeadlineExceeded,
context.Canceled,
)
}
// nocopy when embedded will signal linter to
// error on pass-by-value of parent struct.
@ -27,6 +50,26 @@ func (*nocopy) Lock() {}
func (*nocopy) Unlock() {}
// tryStart will attempt to start the given cache only if sweep duration > 0 (sweeping is enabled).
func tryStart[ValueType any](cache *result.Cache[ValueType], sweep time.Duration) {
if sweep > 0 {
var z ValueType
msg := fmt.Sprintf("starting %T cache", z)
tryUntil(msg, 5, func() bool {
return cache.Start(sweep)
})
}
}
// tryStop will attempt to stop the given cache only if sweep duration > 0 (sweeping is enabled).
func tryStop[ValueType any](cache *result.Cache[ValueType], sweep time.Duration) {
if sweep > 0 {
var z ValueType
msg := fmt.Sprintf("stopping %T cache", z)
tryUntil(msg, 5, cache.Stop)
}
}
// tryUntil will attempt to call 'do' for 'count' attempts, before panicking with 'msg'.
func tryUntil(msg string, count int, do func() bool) {
for i := 0; i < count; i++ {

81
internal/cache/visibility.go vendored Normal file
View File

@ -0,0 +1,81 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package cache
import (
"codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
type VisibilityCache struct {
*result.Cache[*CachedVisibility]
}
// Init will initialize the visibility cache in this collection.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
func (c *VisibilityCache) Init() {
c.Cache = result.New([]result.Lookup{
{Name: "ItemID"},
{Name: "RequesterID"},
{Name: "Type.RequesterID.ItemID"},
}, func(v1 *CachedVisibility) *CachedVisibility {
v2 := new(CachedVisibility)
*v2 = *v1
return v2
}, config.GetCacheVisibilityMaxSize())
c.Cache.SetTTL(config.GetCacheVisibilityTTL(), true)
c.Cache.IgnoreErrors(ignoreErrors)
}
// Start will attempt to start the visibility cache, or panic.
func (c *VisibilityCache) Start() {
tryStart(c.Cache, config.GetCacheVisibilitySweepFreq())
}
// Stop will attempt to stop the visibility cache, or panic.
func (c *VisibilityCache) Stop() {
tryStop(c.Cache, config.GetCacheVisibilitySweepFreq())
}
// VisibilityType represents a visibility lookup type.
// We use a byte type here to improve performance in the
// result cache when generating the key.
type VisibilityType byte
const (
// Possible cache visibility lookup types.
VisibilityTypeAccount = VisibilityType('a')
VisibilityTypeStatus = VisibilityType('s')
VisibilityTypeHome = VisibilityType('h')
VisibilityTypePublic = VisibilityType('p')
)
// CachedVisibility represents a cached visibility lookup value.
type CachedVisibility struct {
// ItemID is the ID of the item in question (status / account).
ItemID string
// RequesterID is the ID of the requesting account for this visibility lookup.
RequesterID string
// Type is the visibility lookup type.
Type VisibilityType
// Value is the actual visibility value.
Value bool
}

View File

@ -157,6 +157,10 @@ type Configuration struct {
type CacheConfiguration struct {
GTS GTSCacheConfiguration `name:"gts"`
VisibilityMaxSize int `name:"visibility-max-size"`
VisibilityTTL time.Duration `name:"visibility-ttl"`
VisibilitySweepFreq time.Duration `name:"visibility-sweep-freq"`
}
type GTSCacheConfiguration struct {
@ -180,6 +184,14 @@ type GTSCacheConfiguration struct {
EmojiCategoryTTL time.Duration `name:"emoji-category-ttl"`
EmojiCategorySweepFreq time.Duration `name:"emoji-category-sweep-freq"`
FollowMaxSize int `name:"follow-max-size"`
FollowTTL time.Duration `name:"follow-ttl"`
FollowSweepFreq time.Duration `name:"follow-sweep-freq"`
FollowRequestMaxSize int `name:"follow-request-max-size"`
FollowRequestTTL time.Duration `name:"follow-request-ttl"`
FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"`
MediaMaxSize int `name:"media-max-size"`
MediaTTL time.Duration `name:"media-ttl"`
MediaSweepFreq time.Duration `name:"media-sweep-freq"`
@ -200,6 +212,10 @@ type GTSCacheConfiguration struct {
StatusTTL time.Duration `name:"status-ttl"`
StatusSweepFreq time.Duration `name:"status-sweep-freq"`
StatusFaveMaxSize int `name:"status-fave-max-size"`
StatusFaveTTL time.Duration `name:"status-fave-ttl"`
StatusFaveSweepFreq time.Duration `name:"status-fave-sweep-freq"`
TombstoneMaxSize int `name:"tombstone-max-size"`
TombstoneTTL time.Duration `name:"tombstone-ttl"`
TombstoneSweepFreq time.Duration `name:"tombstone-sweep-freq"`

View File

@ -51,7 +51,7 @@ var Defaults = Configuration{
DbSqliteJournalMode: "WAL",
DbSqliteSynchronous: "NORMAL",
DbSqliteCacheSize: 8 * bytesize.MiB,
DbSqliteBusyTimeout: time.Minute * 5,
DbSqliteBusyTimeout: time.Minute * 30,
WebTemplateBaseDir: "./web/template/",
WebAssetBaseDir: "./web/assets/",
@ -119,58 +119,74 @@ var Defaults = Configuration{
Cache: CacheConfiguration{
GTS: GTSCacheConfiguration{
AccountMaxSize: 500,
AccountTTL: time.Minute * 5,
AccountSweepFreq: time.Second * 30,
AccountMaxSize: 2000,
AccountTTL: time.Minute * 30,
AccountSweepFreq: time.Minute,
BlockMaxSize: 100,
BlockTTL: time.Minute * 5,
BlockSweepFreq: time.Second * 30,
BlockMaxSize: 1000,
BlockTTL: time.Minute * 30,
BlockSweepFreq: time.Minute,
DomainBlockMaxSize: 1000,
DomainBlockMaxSize: 2000,
DomainBlockTTL: time.Hour * 24,
DomainBlockSweepFreq: time.Minute,
EmojiMaxSize: 500,
EmojiTTL: time.Minute * 5,
EmojiSweepFreq: time.Second * 30,
EmojiMaxSize: 2000,
EmojiTTL: time.Minute * 30,
EmojiSweepFreq: time.Minute,
EmojiCategoryMaxSize: 100,
EmojiCategoryTTL: time.Minute * 5,
EmojiCategorySweepFreq: time.Second * 30,
EmojiCategoryTTL: time.Minute * 30,
EmojiCategorySweepFreq: time.Minute,
MediaMaxSize: 500,
MediaTTL: time.Minute * 5,
MediaSweepFreq: time.Second * 30,
FollowMaxSize: 2000,
FollowTTL: time.Minute * 30,
FollowSweepFreq: time.Minute,
MentionMaxSize: 500,
MentionTTL: time.Minute * 5,
MentionSweepFreq: time.Second * 30,
FollowRequestMaxSize: 2000,
FollowRequestTTL: time.Minute * 30,
FollowRequestSweepFreq: time.Minute,
NotificationMaxSize: 500,
NotificationTTL: time.Minute * 5,
NotificationSweepFreq: time.Second * 30,
MediaMaxSize: 1000,
MediaTTL: time.Minute * 30,
MediaSweepFreq: time.Minute,
MentionMaxSize: 2000,
MentionTTL: time.Minute * 30,
MentionSweepFreq: time.Minute,
NotificationMaxSize: 1000,
NotificationTTL: time.Minute * 30,
NotificationSweepFreq: time.Minute,
ReportMaxSize: 100,
ReportTTL: time.Minute * 5,
ReportSweepFreq: time.Second * 30,
ReportTTL: time.Minute * 30,
ReportSweepFreq: time.Minute,
StatusMaxSize: 500,
StatusTTL: time.Minute * 5,
StatusSweepFreq: time.Second * 30,
StatusMaxSize: 2000,
StatusTTL: time.Minute * 30,
StatusSweepFreq: time.Minute,
TombstoneMaxSize: 100,
TombstoneTTL: time.Minute * 5,
TombstoneSweepFreq: time.Second * 30,
StatusFaveMaxSize: 2000,
StatusFaveTTL: time.Minute * 30,
StatusFaveSweepFreq: time.Minute,
UserMaxSize: 100,
UserTTL: time.Minute * 5,
UserSweepFreq: time.Second * 30,
TombstoneMaxSize: 500,
TombstoneTTL: time.Minute * 30,
TombstoneSweepFreq: time.Minute,
UserMaxSize: 500,
UserTTL: time.Minute * 30,
UserSweepFreq: time.Minute,
WebfingerMaxSize: 250,
WebfingerTTL: time.Hour * 24,
WebfingerSweepFreq: time.Minute * 15,
},
VisibilityMaxSize: 2000,
VisibilityTTL: time.Minute * 30,
VisibilitySweepFreq: time.Minute,
},
AdminMediaPruneDryRun: true,

View File

@ -2501,6 +2501,158 @@ func GetCacheGTSEmojiCategorySweepFreq() time.Duration {
// SetCacheGTSEmojiCategorySweepFreq safely sets the value for global configuration 'Cache.GTS.EmojiCategorySweepFreq' field
func SetCacheGTSEmojiCategorySweepFreq(v time.Duration) { global.SetCacheGTSEmojiCategorySweepFreq(v) }
// GetCacheGTSFollowMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowMaxSize' field
func (st *ConfigState) GetCacheGTSFollowMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.GTS.FollowMaxSize
st.mutex.Unlock()
return
}
// SetCacheGTSFollowMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowMaxSize' field
func (st *ConfigState) SetCacheGTSFollowMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.FollowMaxSize = v
st.reloadToViper()
}
// CacheGTSFollowMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowMaxSize' field
func CacheGTSFollowMaxSizeFlag() string { return "cache-gts-follow-max-size" }
// GetCacheGTSFollowMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowMaxSize' field
func GetCacheGTSFollowMaxSize() int { return global.GetCacheGTSFollowMaxSize() }
// SetCacheGTSFollowMaxSize safely sets the value for global configuration 'Cache.GTS.FollowMaxSize' field
func SetCacheGTSFollowMaxSize(v int) { global.SetCacheGTSFollowMaxSize(v) }
// GetCacheGTSFollowTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowTTL' field
func (st *ConfigState) GetCacheGTSFollowTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.FollowTTL
st.mutex.Unlock()
return
}
// SetCacheGTSFollowTTL safely sets the Configuration value for state's 'Cache.GTS.FollowTTL' field
func (st *ConfigState) SetCacheGTSFollowTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.FollowTTL = v
st.reloadToViper()
}
// CacheGTSFollowTTLFlag returns the flag name for the 'Cache.GTS.FollowTTL' field
func CacheGTSFollowTTLFlag() string { return "cache-gts-follow-ttl" }
// GetCacheGTSFollowTTL safely fetches the value for global configuration 'Cache.GTS.FollowTTL' field
func GetCacheGTSFollowTTL() time.Duration { return global.GetCacheGTSFollowTTL() }
// SetCacheGTSFollowTTL safely sets the value for global configuration 'Cache.GTS.FollowTTL' field
func SetCacheGTSFollowTTL(v time.Duration) { global.SetCacheGTSFollowTTL(v) }
// GetCacheGTSFollowSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowSweepFreq' field
func (st *ConfigState) GetCacheGTSFollowSweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.FollowSweepFreq
st.mutex.Unlock()
return
}
// SetCacheGTSFollowSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowSweepFreq' field
func (st *ConfigState) SetCacheGTSFollowSweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.FollowSweepFreq = v
st.reloadToViper()
}
// CacheGTSFollowSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowSweepFreq' field
func CacheGTSFollowSweepFreqFlag() string { return "cache-gts-follow-sweep-freq" }
// GetCacheGTSFollowSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowSweepFreq' field
func GetCacheGTSFollowSweepFreq() time.Duration { return global.GetCacheGTSFollowSweepFreq() }
// SetCacheGTSFollowSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowSweepFreq' field
func SetCacheGTSFollowSweepFreq(v time.Duration) { global.SetCacheGTSFollowSweepFreq(v) }
// GetCacheGTSFollowRequestMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestMaxSize' field
func (st *ConfigState) GetCacheGTSFollowRequestMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.GTS.FollowRequestMaxSize
st.mutex.Unlock()
return
}
// SetCacheGTSFollowRequestMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowRequestMaxSize' field
func (st *ConfigState) SetCacheGTSFollowRequestMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.FollowRequestMaxSize = v
st.reloadToViper()
}
// CacheGTSFollowRequestMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowRequestMaxSize' field
func CacheGTSFollowRequestMaxSizeFlag() string { return "cache-gts-follow-request-max-size" }
// GetCacheGTSFollowRequestMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowRequestMaxSize' field
func GetCacheGTSFollowRequestMaxSize() int { return global.GetCacheGTSFollowRequestMaxSize() }
// SetCacheGTSFollowRequestMaxSize safely sets the value for global configuration 'Cache.GTS.FollowRequestMaxSize' field
func SetCacheGTSFollowRequestMaxSize(v int) { global.SetCacheGTSFollowRequestMaxSize(v) }
// GetCacheGTSFollowRequestTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestTTL' field
func (st *ConfigState) GetCacheGTSFollowRequestTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.FollowRequestTTL
st.mutex.Unlock()
return
}
// SetCacheGTSFollowRequestTTL safely sets the Configuration value for state's 'Cache.GTS.FollowRequestTTL' field
func (st *ConfigState) SetCacheGTSFollowRequestTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.FollowRequestTTL = v
st.reloadToViper()
}
// CacheGTSFollowRequestTTLFlag returns the flag name for the 'Cache.GTS.FollowRequestTTL' field
func CacheGTSFollowRequestTTLFlag() string { return "cache-gts-follow-request-ttl" }
// GetCacheGTSFollowRequestTTL safely fetches the value for global configuration 'Cache.GTS.FollowRequestTTL' field
func GetCacheGTSFollowRequestTTL() time.Duration { return global.GetCacheGTSFollowRequestTTL() }
// SetCacheGTSFollowRequestTTL safely sets the value for global configuration 'Cache.GTS.FollowRequestTTL' field
func SetCacheGTSFollowRequestTTL(v time.Duration) { global.SetCacheGTSFollowRequestTTL(v) }
// GetCacheGTSFollowRequestSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestSweepFreq' field
func (st *ConfigState) GetCacheGTSFollowRequestSweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.FollowRequestSweepFreq
st.mutex.Unlock()
return
}
// SetCacheGTSFollowRequestSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowRequestSweepFreq' field
func (st *ConfigState) SetCacheGTSFollowRequestSweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.FollowRequestSweepFreq = v
st.reloadToViper()
}
// CacheGTSFollowRequestSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowRequestSweepFreq' field
func CacheGTSFollowRequestSweepFreqFlag() string { return "cache-gts-follow-request-sweep-freq" }
// GetCacheGTSFollowRequestSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
func GetCacheGTSFollowRequestSweepFreq() time.Duration {
return global.GetCacheGTSFollowRequestSweepFreq()
}
// SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) }
// GetCacheGTSMediaMaxSize safely fetches the Configuration value for state's 'Cache.GTS.MediaMaxSize' field
func (st *ConfigState) GetCacheGTSMediaMaxSize() (v int) {
st.mutex.Lock()
@ -2878,6 +3030,81 @@ func GetCacheGTSStatusSweepFreq() time.Duration { return global.GetCacheGTSStatu
// SetCacheGTSStatusSweepFreq safely sets the value for global configuration 'Cache.GTS.StatusSweepFreq' field
func SetCacheGTSStatusSweepFreq(v time.Duration) { global.SetCacheGTSStatusSweepFreq(v) }
// GetCacheGTSStatusFaveMaxSize safely fetches the Configuration value for state's 'Cache.GTS.StatusFaveMaxSize' field
func (st *ConfigState) GetCacheGTSStatusFaveMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.GTS.StatusFaveMaxSize
st.mutex.Unlock()
return
}
// SetCacheGTSStatusFaveMaxSize safely sets the Configuration value for state's 'Cache.GTS.StatusFaveMaxSize' field
func (st *ConfigState) SetCacheGTSStatusFaveMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.StatusFaveMaxSize = v
st.reloadToViper()
}
// CacheGTSStatusFaveMaxSizeFlag returns the flag name for the 'Cache.GTS.StatusFaveMaxSize' field
func CacheGTSStatusFaveMaxSizeFlag() string { return "cache-gts-status-fave-max-size" }
// GetCacheGTSStatusFaveMaxSize safely fetches the value for global configuration 'Cache.GTS.StatusFaveMaxSize' field
func GetCacheGTSStatusFaveMaxSize() int { return global.GetCacheGTSStatusFaveMaxSize() }
// SetCacheGTSStatusFaveMaxSize safely sets the value for global configuration 'Cache.GTS.StatusFaveMaxSize' field
func SetCacheGTSStatusFaveMaxSize(v int) { global.SetCacheGTSStatusFaveMaxSize(v) }
// GetCacheGTSStatusFaveTTL safely fetches the Configuration value for state's 'Cache.GTS.StatusFaveTTL' field
func (st *ConfigState) GetCacheGTSStatusFaveTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.StatusFaveTTL
st.mutex.Unlock()
return
}
// SetCacheGTSStatusFaveTTL safely sets the Configuration value for state's 'Cache.GTS.StatusFaveTTL' field
func (st *ConfigState) SetCacheGTSStatusFaveTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.StatusFaveTTL = v
st.reloadToViper()
}
// CacheGTSStatusFaveTTLFlag returns the flag name for the 'Cache.GTS.StatusFaveTTL' field
func CacheGTSStatusFaveTTLFlag() string { return "cache-gts-status-fave-ttl" }
// GetCacheGTSStatusFaveTTL safely fetches the value for global configuration 'Cache.GTS.StatusFaveTTL' field
func GetCacheGTSStatusFaveTTL() time.Duration { return global.GetCacheGTSStatusFaveTTL() }
// SetCacheGTSStatusFaveTTL safely sets the value for global configuration 'Cache.GTS.StatusFaveTTL' field
func SetCacheGTSStatusFaveTTL(v time.Duration) { global.SetCacheGTSStatusFaveTTL(v) }
// GetCacheGTSStatusFaveSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.StatusFaveSweepFreq' field
func (st *ConfigState) GetCacheGTSStatusFaveSweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.StatusFaveSweepFreq
st.mutex.Unlock()
return
}
// SetCacheGTSStatusFaveSweepFreq safely sets the Configuration value for state's 'Cache.GTS.StatusFaveSweepFreq' field
func (st *ConfigState) SetCacheGTSStatusFaveSweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.StatusFaveSweepFreq = v
st.reloadToViper()
}
// CacheGTSStatusFaveSweepFreqFlag returns the flag name for the 'Cache.GTS.StatusFaveSweepFreq' field
func CacheGTSStatusFaveSweepFreqFlag() string { return "cache-gts-status-fave-sweep-freq" }
// GetCacheGTSStatusFaveSweepFreq safely fetches the value for global configuration 'Cache.GTS.StatusFaveSweepFreq' field
func GetCacheGTSStatusFaveSweepFreq() time.Duration { return global.GetCacheGTSStatusFaveSweepFreq() }
// SetCacheGTSStatusFaveSweepFreq safely sets the value for global configuration 'Cache.GTS.StatusFaveSweepFreq' field
func SetCacheGTSStatusFaveSweepFreq(v time.Duration) { global.SetCacheGTSStatusFaveSweepFreq(v) }
// GetCacheGTSTombstoneMaxSize safely fetches the Configuration value for state's 'Cache.GTS.TombstoneMaxSize' field
func (st *ConfigState) GetCacheGTSTombstoneMaxSize() (v int) {
st.mutex.Lock()
@ -3103,6 +3330,81 @@ func GetCacheGTSWebfingerSweepFreq() time.Duration { return global.GetCacheGTSWe
// SetCacheGTSWebfingerSweepFreq safely sets the value for global configuration 'Cache.GTS.WebfingerSweepFreq' field
func SetCacheGTSWebfingerSweepFreq(v time.Duration) { global.SetCacheGTSWebfingerSweepFreq(v) }
// GetCacheVisibilityMaxSize safely fetches the Configuration value for state's 'Cache.VisibilityMaxSize' field
func (st *ConfigState) GetCacheVisibilityMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.VisibilityMaxSize
st.mutex.Unlock()
return
}
// SetCacheVisibilityMaxSize safely sets the Configuration value for state's 'Cache.VisibilityMaxSize' field
func (st *ConfigState) SetCacheVisibilityMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.VisibilityMaxSize = v
st.reloadToViper()
}
// CacheVisibilityMaxSizeFlag returns the flag name for the 'Cache.VisibilityMaxSize' field
func CacheVisibilityMaxSizeFlag() string { return "cache-visibility-max-size" }
// GetCacheVisibilityMaxSize safely fetches the value for global configuration 'Cache.VisibilityMaxSize' field
func GetCacheVisibilityMaxSize() int { return global.GetCacheVisibilityMaxSize() }
// SetCacheVisibilityMaxSize safely sets the value for global configuration 'Cache.VisibilityMaxSize' field
func SetCacheVisibilityMaxSize(v int) { global.SetCacheVisibilityMaxSize(v) }
// GetCacheVisibilityTTL safely fetches the Configuration value for state's 'Cache.VisibilityTTL' field
func (st *ConfigState) GetCacheVisibilityTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.VisibilityTTL
st.mutex.Unlock()
return
}
// SetCacheVisibilityTTL safely sets the Configuration value for state's 'Cache.VisibilityTTL' field
func (st *ConfigState) SetCacheVisibilityTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.VisibilityTTL = v
st.reloadToViper()
}
// CacheVisibilityTTLFlag returns the flag name for the 'Cache.VisibilityTTL' field
func CacheVisibilityTTLFlag() string { return "cache-visibility-ttl" }
// GetCacheVisibilityTTL safely fetches the value for global configuration 'Cache.VisibilityTTL' field
func GetCacheVisibilityTTL() time.Duration { return global.GetCacheVisibilityTTL() }
// SetCacheVisibilityTTL safely sets the value for global configuration 'Cache.VisibilityTTL' field
func SetCacheVisibilityTTL(v time.Duration) { global.SetCacheVisibilityTTL(v) }
// GetCacheVisibilitySweepFreq safely fetches the Configuration value for state's 'Cache.VisibilitySweepFreq' field
func (st *ConfigState) GetCacheVisibilitySweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.VisibilitySweepFreq
st.mutex.Unlock()
return
}
// SetCacheVisibilitySweepFreq safely sets the Configuration value for state's 'Cache.VisibilitySweepFreq' field
func (st *ConfigState) SetCacheVisibilitySweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.VisibilitySweepFreq = v
st.reloadToViper()
}
// CacheVisibilitySweepFreqFlag returns the flag name for the 'Cache.VisibilitySweepFreq' field
func CacheVisibilitySweepFreqFlag() string { return "cache-visibility-sweep-freq" }
// GetCacheVisibilitySweepFreq safely fetches the value for global configuration 'Cache.VisibilitySweepFreq' field
func GetCacheVisibilitySweepFreq() time.Duration { return global.GetCacheVisibilitySweepFreq() }
// SetCacheVisibilitySweepFreq safely sets the value for global configuration 'Cache.VisibilitySweepFreq' field
func SetCacheVisibilitySweepFreq(v time.Duration) { global.SetCacheVisibilitySweepFreq(v) }
// GetAdminAccountUsername safely fetches the Configuration value for state's 'AdminAccountUsername' field
func (st *ConfigState) GetAdminAccountUsername() (v string) {
st.mutex.Lock()

View File

@ -41,6 +41,21 @@ type Account interface {
// GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong.
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong.
GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong.
GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong.
GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong.
GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
PopulateAccount(ctx context.Context, account *gtsmodel.Account) error
// PutAccount puts one account in the database.
PutAccount(ctx context.Context, account *gtsmodel.Account) Error

View File

@ -20,11 +20,13 @@ package bundb
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -37,18 +39,15 @@ type accountDB struct {
state *state.State
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
return a.conn.
NewSelect().
Model(account)
}
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"ID",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.id"), id).
Scan(ctx)
},
id,
)
@ -59,7 +58,10 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
ctx,
"URI",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.uri"), uri).
Scan(ctx)
},
uri,
)
@ -70,7 +72,10 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
ctx,
"URL",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.url"), url).
Scan(ctx)
},
url,
)
@ -81,7 +86,8 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
ctx,
"Username.Domain",
func(account *gtsmodel.Account) error {
q := a.newAccountQ(account)
q := a.conn.NewSelect().
Model(account)
if domain != "" {
q = q.
@ -105,12 +111,71 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
ctx,
"PublicKeyURI",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.public_key_uri"), id).
Scan(ctx)
},
id,
)
}
func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"InboxURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.inbox_uri"), uri).
Scan(ctx)
},
uri,
)
}
func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"OutboxURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.outbox_uri"), uri).
Scan(ctx)
},
uri,
)
}
func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"FollowersURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.followers_uri"), uri).
Scan(ctx)
},
uri,
)
}
func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"FollowingURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.following_uri"), uri).
Scan(ctx)
},
uri,
)
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
var username string
@ -141,33 +206,58 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(
return nil, err
}
if account.AvatarMediaAttachmentID != "" {
// Set the account's related avatar
account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
if err != nil {
log.Errorf(ctx, "error getting account %s avatar: %v", account.ID, err)
}
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return account, nil
}
if account.HeaderMediaAttachmentID != "" {
// Set the account's related header
account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.HeaderMediaAttachmentID)
if err != nil {
log.Errorf(ctx, "error getting account %s header: %v", account.ID, err)
}
}
if len(account.EmojiIDs) > 0 {
// Set the account's related emojis
account.Emojis, err = a.state.DB.GetEmojisByIDs(ctx, account.EmojiIDs)
if err != nil {
log.Errorf(ctx, "error getting account %s emojis: %v", account.ID, err)
}
// Further populate the account fields where applicable.
if err := a.PopulateAccount(ctx, account); err != nil {
return nil, err
}
return account, nil
}
func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error {
var err error
if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" {
// Account avatar attachment is not set, fetch from database.
account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(
ctx, // these are already barebones
account.AvatarMediaAttachmentID,
)
if err != nil {
return fmt.Errorf("error populating account avatar: %w", err)
}
}
if account.HeaderMediaAttachment == nil && account.HeaderMediaAttachmentID != "" {
// Account header attachment is not set, fetch from database.
account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(
ctx, // these are already barebones
account.HeaderMediaAttachmentID,
)
if err != nil {
return fmt.Errorf("error populating account header: %w", err)
}
}
if !account.EmojisPopulated() {
// Account emojis are out-of-date with IDs, repopulate.
account.Emojis, err = a.state.DB.GetEmojisByIDs(
ctx, // these are already barebones
account.EmojiIDs,
)
if err != nil {
return fmt.Errorf("error populating account emojis: %w", err)
}
}
return nil
}
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
return a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
@ -198,7 +288,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "updated_at")
}
return a.state.Caches.GTS.Account().Store(account, func() error {
err := a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@ -234,6 +324,11 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
return err
})
})
if err != nil {
return err
}
return nil
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
@ -258,7 +353,9 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
return err
}
// Invalidate account from database lookups.
a.state.Caches.GTS.Account().Invalidate("ID", id)
return nil
}

View File

@ -21,6 +21,8 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"reflect"
"strings"
"testing"
"time"
@ -61,44 +63,149 @@ func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() {
suite.Len(statuses, 1)
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
func (suite *AccountTestSuite) TestGetAccountBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 account models are equal.
isEqual := func(a1, a2 gtsmodel.Account) bool {
// Clear populated sub-models.
a1.HeaderMediaAttachment = nil
a2.HeaderMediaAttachment = nil
a1.AvatarMediaAttachment = nil
a2.AvatarMediaAttachment = nil
a1.Emojis = nil
a2.Emojis = nil
// Clear database-set fields.
a1.CreatedAt = time.Time{}
a2.CreatedAt = time.Time{}
a1.UpdatedAt = time.Time{}
a2.UpdatedAt = time.Time{}
// Manually compare keys.
pk1 := a1.PublicKey
pv1 := a1.PrivateKey
pk2 := a2.PublicKey
pv2 := a2.PrivateKey
a1.PublicKey = nil
a1.PrivateKey = nil
a2.PublicKey = nil
a2.PrivateKey = nil
return reflect.DeepEqual(a1, a2) &&
((pk1 == nil && pk2 == nil) || pk1.Equal(pk2)) &&
((pv1 == nil && pv2 == nil) || pv1.Equal(pv2))
}
for _, account := range suite.testAccounts {
for lookup, dbfunc := range map[string]func() (*gtsmodel.Account, error){
"id": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByID(ctx, account.ID)
},
"uri": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByURI(ctx, account.URI)
},
"url": func() (*gtsmodel.Account, error) {
if account.URL == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByURL(ctx, account.URL)
},
"username@domain": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByUsernameDomain(ctx, account.Username, account.Domain)
},
"username_upper@domain": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByUsernameDomain(ctx, strings.ToUpper(account.Username), account.Domain)
},
"username_lower@domain": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByUsernameDomain(ctx, strings.ToLower(account.Username), account.Domain)
},
"public_key_uri": func() (*gtsmodel.Account, error) {
if account.PublicKeyURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByPubkeyID(ctx, account.PublicKeyURI)
},
"inbox_uri": func() (*gtsmodel.Account, error) {
if account.InboxURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByInboxURI(ctx, account.InboxURI)
},
"outbox_uri": func() (*gtsmodel.Account, error) {
if account.OutboxURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByOutboxURI(ctx, account.OutboxURI)
},
"following_uri": func() (*gtsmodel.Account, error) {
if account.FollowingURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByFollowingURI(ctx, account.FollowingURI)
},
"followers_uri": func() (*gtsmodel.Account, error) {
if account.FollowersURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByFollowersURI(ctx, account.FollowersURI)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkAcc, err := dbfunc()
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(account)
suite.NotNil(account.AvatarMediaAttachment)
suite.NotEmpty(account.AvatarMediaAttachment.URL)
suite.NotNil(account.HeaderMediaAttachment)
suite.NotEmpty(account.HeaderMediaAttachment.URL)
if err == sentinelErr {
continue
}
func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
testAccount1 := suite.testAccounts["local_account_1"]
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
suite.NoError(err)
suite.NotNil(account1)
testAccount2 := suite.testAccounts["remote_account_1"]
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
suite.NoError(err)
suite.NotNil(account2)
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
func (suite *AccountTestSuite) TestGetAccountByUsernameDomainMixedCase() {
testAccount := suite.testAccounts["remote_account_2"]
// Check received account data.
if !isEqual(*checkAcc, *account) {
t.Errorf("account does not contain expected data: %+v", checkAcc)
continue
}
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount.Username, testAccount.Domain)
suite.NoError(err)
suite.NotNil(account1)
// Check that avatar attachment populated.
if account.AvatarMediaAttachmentID != "" &&
(checkAcc.AvatarMediaAttachment == nil || checkAcc.AvatarMediaAttachment.ID != account.AvatarMediaAttachmentID) {
t.Errorf("account avatar media attachment not correctly populated for: %+v", account)
continue
}
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToUpper(testAccount.Username), testAccount.Domain)
suite.NoError(err)
suite.NotNil(account2)
account3, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToLower(testAccount.Username), testAccount.Domain)
suite.NoError(err)
suite.NotNil(account3)
// Check that header attachment populated.
if account.HeaderMediaAttachmentID != "" &&
(checkAcc.HeaderMediaAttachment == nil || checkAcc.HeaderMediaAttachment.ID != account.HeaderMediaAttachmentID) {
t.Errorf("account header media attachment not correctly populated for: %+v", account)
continue
}
}
}
}
func (suite *AccountTestSuite) TestUpdateAccount() {

View File

@ -19,6 +19,8 @@ package bundb_test
import (
"context"
"crypto/rand"
"crypto/rsa"
"testing"
"time"
@ -40,6 +42,12 @@ func (suite *BasicTestSuite) TestGetAccountByID() {
}
func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
suite.FailNow(err.Error())
}
// Create an account that only just matches constraints.
testAccount := &gtsmodel.Account{
ID: "01GADR1AH9VCKH8YYCM86XSZ00",
Username: "test",
@ -49,6 +57,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
OutboxURI: "https://example.org/users/test/outbox",
ActorType: "Person",
PublicKeyURI: "https://example.org/test#main-key",
PublicKey: &key.PublicKey,
}
if err := suite.db.Put(context.Background(), testAccount); err != nil {
@ -99,7 +108,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
suite.Empty(a.FeaturedCollectionURI)
suite.Equal(testAccount.ActorType, a.ActorType)
suite.Nil(a.PrivateKey)
suite.Nil(a.PublicKey)
suite.EqualValues(key.PublicKey, *a.PublicKey)
suite.Equal(testAccount.PublicKeyURI, a.PublicKeyURI)
suite.Zero(a.SensitizedAt)
suite.Zero(a.SilencedAt)

View File

@ -47,6 +47,24 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
)
}
func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) {
attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
for _, id := range ids {
// Attempt fetch from DB
attachment, err := m.GetAttachmentByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting attachment %q: %v", id, err)
continue
}
// Append attachment
attachments = append(attachments, attachment)
}
return attachments, nil
}
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) {
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment
@ -118,7 +136,7 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
return nil, m.conn.ProcessError(err)
}
return m.getAttachments(ctx, attachmentIDs)
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
@ -163,7 +181,7 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
return nil, m.conn.ProcessError(err)
}
return m.getAttachments(ctx, attachmentIDs)
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
@ -189,7 +207,7 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
return nil, m.conn.ProcessError(err)
}
return m.getAttachments(ctx, attachmentIDs)
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
@ -211,21 +229,3 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t
return count, nil
}
func (m *mediaDB) getAttachments(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, db.Error) {
attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
for _, id := range ids {
// Attempt fetch from DB
attachment, err := m.GetAttachmentByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting attachment %q: %v", id, err)
continue
}
// Append attachment
attachments = append(attachments, attachment)
}
return attachments, nil
}

View File

@ -19,8 +19,10 @@ package bundb
import (
"context"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -32,20 +34,13 @@ type mentionDB struct {
state *state.State
}
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
return m.conn.
NewSelect().
Model(i).
Relation("Status").
Relation("OriginAccount").
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
return m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention
q := m.newMentionQ(&mention).
q := m.conn.
NewSelect().
Model(&mention).
Where("? = ?", bun.Ident("mention.id"), id)
if err := q.Scan(ctx); err != nil {
@ -54,6 +49,38 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
return &mention, nil
}, id)
if err != nil {
return nil, err
}
// Set the mention originating status.
mention.Status, err = m.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
mention.StatusID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention status: %w", err)
}
// Set the mention origin account model.
mention.OriginAccount, err = m.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
mention.OriginAccountID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention origin account: %w", err)
}
// Set the mention target account model.
mention.TargetAccount, err = m.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
mention.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention target account: %w", err)
}
return mention, nil
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
@ -73,3 +100,25 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
return mentions, nil
}
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
return m.state.Caches.GTS.Mention().Store(mention, func() error {
_, err := m.conn.NewInsert().Model(mention).Exec(ctx)
return m.conn.ProcessError(err)
})
}
func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
if _, err := m.conn.
NewDelete().
Table("mentions").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return m.conn.ProcessError(err)
}
// Invalidate mention from the lookup cache.
m.state.Caches.GTS.Mention().Invalidate("ID", id)
return nil
}

View File

@ -0,0 +1,167 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
// To update unique constraint on public key, we need to migrate accounts into a new table.
// See section 7 here: https://www.sqlite.org/lang_altertable.html
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Create the new accounts table.
if _, err := tx.
NewCreateTable().
ModelTableExpr("new_accounts").
Model(&gtsmodel.Account{}).
Exec(ctx); err != nil {
return err
}
// If we don't specify columns explicitly,
// Postgres gives the following error when
// transferring accounts to new_accounts:
//
// ERROR: column "fetched_at" is of type timestamp with time zone but expression is of type character varying at character 35
// HINT: You will need to rewrite or cast the expression.
//
// Rather than do funky casting to fix this,
// it's simpler to just specify all columns.
columns := []string{
"id",
"created_at",
"updated_at",
"fetched_at",
"username",
"domain",
"avatar_media_attachment_id",
"avatar_remote_url",
"header_media_attachment_id",
"header_remote_url",
"display_name",
"emojis",
"fields",
"note",
"note_raw",
"memorial",
"also_known_as",
"moved_to_account_id",
"bot",
"reason",
"locked",
"discoverable",
"privacy",
"sensitive",
"language",
"status_content_type",
"custom_css",
"uri",
"url",
"inbox_uri",
"shared_inbox_uri",
"outbox_uri",
"following_uri",
"followers_uri",
"featured_collection_uri",
"actor_type",
"private_key",
"public_key",
"public_key_uri",
"sensitized_at",
"silenced_at",
"suspended_at",
"hide_collections",
"suspension_origin",
"enable_rss",
}
// Copy all accounts to the new table.
if _, err := tx.
NewInsert().
Table("new_accounts").
Table("accounts").
Column(columns...).
Exec(ctx); err != nil {
return err
}
// Drop the old table.
if _, err := tx.
NewDropTable().
Table("accounts").
Exec(ctx); err != nil {
return err
}
// Rename new table to old table.
if _, err := tx.
ExecContext(
ctx,
"ALTER TABLE ? RENAME TO ?",
bun.Ident("new_accounts"),
bun.Ident("accounts"),
); err != nil {
return err
}
// Add all account indexes to the new table.
for index, columns := range map[string][]string{
// Standard indices.
"accounts_id_idx": {"id"},
"accounts_suspended_at_idx": {"suspended_at"},
"accounts_domain_idx": {"domain"},
"accounts_username_domain_idx": {"username", "domain"},
// URI indices.
"accounts_uri_idx": {"uri"},
"accounts_url_idx": {"url"},
"accounts_inbox_uri_idx": {"inbox_uri"},
"accounts_outbox_uri_idx": {"outbox_uri"},
"accounts_followers_uri_idx": {"followers_uri"},
"accounts_following_uri_idx": {"following_uri"},
"accounts_public_key_uri_idx": {"public_key_uri"},
} {
if _, err := tx.
NewCreateIndex().
Table("accounts").
Index(index).
Column(columns...).
Exec(ctx); err != nil {
return err
}
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View File

@ -33,7 +33,7 @@ type notificationDB struct {
state *state.State
}
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
@ -48,7 +48,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
}, id)
}
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
func (n *notificationDB) GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -92,7 +92,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
// reason for this is that for each notif, we can instead get it from our cache if it's cached
for _, id := range notifIDs {
// Attempt fetch from DB
notif, err := n.GetNotification(ctx, id)
notif, err := n.GetNotificationByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting notification %q: %v", id, err)
continue
@ -105,7 +105,14 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
return notifs, nil
}
func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.Error {
func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
return n.state.Caches.GTS.Notification().Store(notif, func() error {
_, err := n.conn.NewInsert().Model(notif).Exec(ctx)
return n.conn.ProcessError(err)
})
}
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error {
if _, err := n.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
@ -118,19 +125,23 @@ func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.E
return nil
}
func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error {
if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set")
}
// Capture notification IDs in a RETURNING statement.
ids := []string{}
var ids []string
q := n.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Returning("?", bun.Ident("id"))
if len(types) > 0 {
q = q.Where("? IN (?)", bun.Ident("notification.notification_type"), bun.In(types))
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID)
}
@ -153,7 +164,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountI
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error {
// Capture notification IDs in a RETURNING statement.
ids := []string{}
var ids []string
q := n.conn.
NewDelete().

View File

@ -85,11 +85,11 @@ type NotificationTestSuite struct {
BunDBStandardTestSuite
}
func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
func (suite *NotificationTestSuite) TestGetAccountNotificationsWithSpam() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
before := time.Now()
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
timeTaken := time.Since(before)
fmt.Printf("\n\n\n withSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
@ -100,10 +100,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
}
}
func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() {
testAccount := suite.testAccounts["local_account_1"]
before := time.Now()
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
timeTaken := time.Since(before)
fmt.Printf("\n\n\n withoutSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
@ -117,10 +117,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
suite.NotNil(notifications)
suite.Empty(notifications)
@ -129,10 +129,10 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
suite.NotNil(notifications)
suite.Empty(notifications)
@ -146,7 +146,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAccount() {
testAccount := suite.testAccounts["local_account_2"]
if err := suite.db.DeleteNotifications(context.Background(), "", testAccount.ID); err != nil {
if err := suite.db.DeleteNotifications(context.Background(), nil, "", testAccount.ID); err != nil {
suite.FailNow(err.Error())
}
@ -166,7 +166,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAndTar
originAccount := suite.testAccounts["local_account_2"]
targetAccount := suite.testAccounts["admin_account"]
if err := suite.db.DeleteNotifications(context.Background(), targetAccount.ID, originAccount.ID); err != nil {
if err := suite.db.DeleteNotifications(context.Background(), nil, targetAccount.ID, originAccount.ID); err != nil {
suite.FailNow(err.Error())
}

View File

@ -23,8 +23,8 @@ import (
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
@ -34,603 +34,212 @@ type relationshipDB struct {
state *state.State
}
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
// Look for a block in direction of account1->account2
block1, err := r.getBlock(ctx, account1, account2)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
if block1 != nil {
// account1 blocks account2
return true, nil
} else if !eitherDirection {
// Don't check for mutli-directional
return false, nil
}
// Look for a block in direction of account2->account1
block2, err := r.getBlock(ctx, account2, account1)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
return (block2 != nil), nil
}
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
// Fetch block from database
block, err := r.getBlock(ctx, account1, account2)
if err != nil {
return nil, err
}
// Set the block originating account
block.Account, err = r.state.DB.GetAccountByID(ctx, block.AccountID)
if err != nil {
return nil, err
}
// Set the block target account
block.TargetAccount, err = r.state.DB.GetAccountByID(ctx, block.TargetAccountID)
if err != nil {
return nil, err
}
return block, nil
}
func (r *relationshipDB) getBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
return r.state.Caches.GTS.Block().Load("AccountID.TargetAccountID", func() (*gtsmodel.Block, error) {
var block gtsmodel.Block
q := r.conn.NewSelect().Model(&block).
Where("? = ?", bun.Ident("block.account_id"), account1).
Where("? = ?", bun.Ident("block.target_account_id"), account2)
if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
return &block, nil
}, account1, account2)
}
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) db.Error {
return r.state.Caches.GTS.Block().Store(block, func() error {
_, err := r.conn.NewInsert().Model(block).Exec(ctx)
return r.conn.ProcessError(err)
})
}
func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) db.Error {
if _, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Where("? = ?", bun.Ident("block.id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Drop any old value from cache by this ID
r.state.Caches.GTS.Block().Invalidate("ID", id)
return nil
}
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) db.Error {
if _, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Where("? = ?", bun.Ident("block.uri"), uri).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Drop any old value from cache by this URI
r.state.Caches.GTS.Block().Invalidate("URI", uri)
return nil
}
func (r *relationshipDB) DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) db.Error {
blockIDs := []string{}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Column("block.id").
Where("? = ?", bun.Ident("block.account_id"), originAccountID)
if err := q.Scan(ctx, &blockIDs); err != nil {
return r.conn.ProcessError(err)
}
for _, blockID := range blockIDs {
if err := r.DeleteBlockByID(ctx, blockID); err != nil {
return err
}
}
return nil
}
func (r *relationshipDB) DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) db.Error {
blockIDs := []string{}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Column("block.id").
Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID)
if err := q.Scan(ctx, &blockIDs); err != nil {
return r.conn.ProcessError(err)
}
for _, blockID := range blockIDs {
if err := r.DeleteBlockByID(ctx, blockID); err != nil {
return err
}
}
return nil
}
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
var rel gtsmodel.Relationship
rel.ID = targetAccount
// check if the requesting follows the target
follow, err := r.GetFollow(
gtscontext.SetBarebones(ctx),
requestingAccount,
targetAccount,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err)
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := r.conn.
NewSelect().
Model(follow).
Column("follow.show_reblogs", "follow.notify").
Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
Limit(1).
Scan(ctx); err != nil {
if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
}
// no follow exists so these are all false
rel.Following = false
rel.ShowingReblogs = false
rel.Notifying = false
} else {
if follow != nil {
// follow exists so we can fill these fields out...
rel.Following = true
rel.ShowingReblogs = *follow.ShowReblogs
rel.Notifying = *follow.Notify
}
// check if the target account follows the requesting account
followedByQ := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.id").
Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
followedBy, err := r.conn.Exists(ctx, followedByQ)
// check if the target follows the requesting
rel.FollowedBy, err = r.IsFollowing(ctx,
targetAccount,
requestingAccount,
)
if err != nil {
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err)
}
rel.FollowedBy = followedBy
// check if there's a pending following request from requesting account to target account
requestedQ := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
requested, err := r.conn.Exists(ctx, requestedQ)
// check if requesting has follow requested target
rel.Requested, err = r.IsFollowRequested(ctx,
requestingAccount,
targetAccount,
)
if err != nil {
return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err)
}
rel.Requested = requested
// check if the requesting account is blocking the target account
blockA2T, err := r.getBlock(ctx, requestingAccount, targetAccount)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)
if err != nil {
return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err)
}
rel.Blocking = (blockA2T != nil)
// check if the requesting account is blocked by the target account
blockT2A, err := r.getBlock(ctx, targetAccount, requestingAccount)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
}
rel.BlockedBy = (blockT2A != nil)
return rel, nil
}
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.id").
Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx, account1, account2)
rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)
if err != nil {
return false, err
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err)
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, err
return &rel, nil
}
return f1 && f2, nil
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// Get original follow request.
var followRequestID string
if err := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
Scan(ctx, &followRequestID); err != nil {
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectFollows(r.conn, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
return r.GetFollowsByIDs(ctx, followIDs)
}
followRequest, err := r.getFollowRequest(ctx, followRequestID)
if err != nil {
func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectLocalFollows(r.conn, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
// Create a new follow to 'replace'
// the original follow request with.
follow := &gtsmodel.Follow{
ID: followRequest.ID,
AccountID: originAccountID,
Account: followRequest.Account,
TargetAccountID: targetAccountID,
TargetAccount: followRequest.TargetAccount,
URI: followRequest.URI,
return r.GetFollowsByIDs(ctx, followIDs)
}
// If the follow already exists, just
// replace the URI with the new one.
if _, err := r.conn.
NewInsert().
Model(follow).
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
Exec(ctx); err != nil {
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectFollowers(r.conn, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
return r.GetFollowsByIDs(ctx, followIDs)
}
// Delete original follow request.
if _, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
Exec(ctx); err != nil {
func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectLocalFollowers(r.conn, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
// Delete original follow request notification.
if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
return nil, err
return r.GetFollowsByIDs(ctx, followIDs)
}
// return the new follow
return follow, nil
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollows(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
// Get original follow request.
var followRequestID string
if err := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
Scan(ctx, &followRequestID); err != nil {
func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowers(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
var followReqIDs []string
if err := newSelectFollowRequests(r.conn, accountID).
Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
followRequest, err := r.getFollowRequest(ctx, followRequestID)
if err != nil {
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
var followReqIDs []string
if err := newSelectFollowRequesting(r.conn, accountID).
Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
// Delete original follow request.
if _, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
// Delete original follow request notification.
if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
return nil, err
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
// Return the now deleted follow request.
return followRequest, nil
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) deleteFollowRequestNotif(ctx context.Context, originAccountID string, targetAccountID string) db.Error {
var id string
if err := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Column("notification.id").
Where("? = ?", bun.Ident("notification.origin_account_id"), originAccountID).
Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("notification.notification_type"), gtsmodel.NotificationFollowRequest).
Limit(1). // There should only be one!
Scan(ctx, &id); err != nil {
err = r.conn.ProcessError(err)
if errors.Is(err, db.ErrNoEntries) {
// If no entries, the notif didn't
// exist anyway so nothing to do here.
return nil
}
// Return on real error.
return err
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
return r.state.DB.DeleteNotification(ctx, id)
// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
func (r *relationshipDB) getFollow(ctx context.Context, id string) (*gtsmodel.Follow, db.Error) {
follow := &gtsmodel.Follow{}
err := r.conn.
NewSelect().
Model(follow).
Where("? = ?", bun.Ident("follow.id"), id).
Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
Table("follows").
Column("id").
Where("? = ?", bun.Ident("account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
follow.Account, err = r.state.DB.GetAccountByID(ctx, follow.AccountID)
if err != nil {
log.Errorf(ctx, "error getting follow account %q: %v", follow.AccountID, err)
}
follow.TargetAccount, err = r.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil {
log.Errorf(ctx, "error getting follow target account %q: %v", follow.TargetAccountID, err)
}
return follow, nil
}
func (r *relationshipDB) GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, db.Error) {
accountIDs := []string{}
// Select only the account ID of each follow.
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
ColumnExpr("? AS ?", bun.Ident("follow.account_id"), bun.Ident("account_id")).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
// Join on accounts table to select only
// those with NULL domain (local accounts).
q = q.
Join("JOIN ? AS ? ON ? = ?",
bun.Ident("accounts"),
bun.Ident("account"),
bun.Ident("follow.account_id"),
bun.Ident("account.id"),
// newSelectLocalFollows returns a new select query for all rows in the follows table with
// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
Table("follows").
Column("id").
Where("? = ? AND ? IN (?)",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
conn.NewSelect().
Table("accounts").
Column("id").
Where("? IS NULL", bun.Ident("domain")),
).
Where("? IS NULL", bun.Ident("account.domain"))
// We don't *really* need to order these,
// but it makes it more consistent to do so.
q = q.Order("account_id DESC")
if err := q.Scan(ctx, &accountIDs); err != nil {
return nil, r.conn.ProcessError(err)
OrderExpr("? DESC", bun.Ident("updated_at"))
}
return accountIDs, nil
// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
Table("follows").
Column("id").
Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
func (r *relationshipDB) GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, db.Error) {
ids := []string{}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.id").
Order("follow.updated_at DESC")
if accountID != "" {
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
}
if err := q.Scan(ctx, &ids); err != nil {
return nil, r.conn.ProcessError(err)
}
follows := make([]*gtsmodel.Follow, 0, len(ids))
for _, id := range ids {
follow, err := r.getFollow(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow %q: %v", id, err)
continue
}
follows = append(follows, follow)
}
return follows, nil
}
func (r *relationshipDB) CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.id")
if accountID != "" {
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
}
return q.Count(ctx)
}
func (r *relationshipDB) getFollowRequest(ctx context.Context, id string) (*gtsmodel.FollowRequest, db.Error) {
followRequest := &gtsmodel.FollowRequest{}
err := r.conn.
NewSelect().
Model(followRequest).
Where("? = ?", bun.Ident("follow_request.id"), id).
Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
}
followRequest.Account, err = r.state.DB.GetAccountByID(ctx, followRequest.AccountID)
if err != nil {
log.Errorf(ctx, "error getting follow request account %q: %v", followRequest.AccountID, err)
}
followRequest.TargetAccount, err = r.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil {
log.Errorf(ctx, "error getting follow request target account %q: %v", followRequest.TargetAccountID, err)
}
return followRequest, nil
}
func (r *relationshipDB) GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, db.Error) {
ids := []string{}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id")
if accountID != "" {
q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
}
if err := q.Scan(ctx, &ids); err != nil {
return nil, r.conn.ProcessError(err)
}
followRequests := make([]*gtsmodel.FollowRequest, 0, len(ids))
for _, id := range ids {
followRequest, err := r.getFollowRequest(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow request %q: %v", id, err)
continue
}
followRequests = append(followRequests, followRequest)
}
return followRequests, nil
}
func (r *relationshipDB) CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Order("follow_request.updated_at DESC")
if accountID != "" {
q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
}
return q.Count(ctx)
}
func (r *relationshipDB) Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
uri := new(string)
_, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("follow.account_id"), originAccountID).
Returning("?", bun.Ident("uri")).Exec(ctx, uri)
// Only return proper errors.
if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
return *uri, err
}
return *uri, nil
}
func (r *relationshipDB) UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
uri := new(string)
_, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
Returning("?", bun.Ident("uri")).Exec(ctx, uri)
// Only return proper errors.
if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
return *uri, err
}
return *uri, nil
// newSelectLocalFollowers returns a new select query for all rows in the follows table with
// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
Table("follows").
Column("id").
Where("? = ? AND ? IN (?)",
bun.Ident("target_account_id"),
accountID,
bun.Ident("account_id"),
conn.NewSelect().
Table("accounts").
Column("id").
Where("? IS NULL", bun.Ident("domain")),
).
OrderExpr("? DESC", bun.Ident("updated_at"))
}

View File

@ -0,0 +1,218 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
block, err := r.GetBlock(
gtscontext.SetBarebones(ctx),
sourceAccountID,
targetAccountID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
return (block != nil), nil
}
func (r *relationshipDB) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) {
// Look for a block in direction of account1->account2
b1, err := r.IsBlocked(ctx, accountID1, accountID2)
if err != nil || b1 {
return true, err
}
// Look for a block in direction of account2->account1
b2, err := r.IsBlocked(ctx, accountID2, accountID1)
if err != nil || b2 {
return true, err
}
return false, nil
}
func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
"ID",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.id"), id).
Scan(ctx)
},
id,
)
}
func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
"URI",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.uri"), uri).
Scan(ctx)
},
uri,
)
}
func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
"AccountID.TargetAccountID",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID).
Scan(ctx)
},
sourceAccountID,
targetAccountID,
)
}
func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
// Fetch block from cache with loader callback
block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
var block gtsmodel.Block
// Not cached! Perform database query
if err := dbQuery(&block); err != nil {
return nil, r.conn.ProcessError(err)
}
return &block, nil
}, keyParts...)
if err != nil {
// already processe
return nil, err
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return block, nil
}
// Set the block source account
block.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
block.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting block source account: %w", err)
}
// Set the block target account
block.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
block.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting block target account: %w", err)
}
return block, nil
}
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
err := r.state.Caches.GTS.Block().Store(block, func() error {
_, err := r.conn.NewInsert().Model(block).Exec(ctx)
return r.conn.ProcessError(err)
})
if err != nil {
return err
}
// Invalidate block origin account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", block.AccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", block.AccountID)
// Invalidate block target account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", block.TargetAccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", block.TargetAccountID)
return nil
}
func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
block, err := r.GetBlockByID(gtscontext.SetBarebones(ctx), id)
if err != nil {
return err
}
return r.deleteBlock(ctx, block)
}
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
block, err := r.GetBlockByURI(gtscontext.SetBarebones(ctx), uri)
if err != nil {
return err
}
return r.deleteBlock(ctx, block)
}
func (r *relationshipDB) deleteBlock(ctx context.Context, block *gtsmodel.Block) error {
if _, err := r.conn.
NewDelete().
Table("blocks").
Where("? = ?", bun.Ident("id"), block.ID).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate block from cache lookups.
r.state.Caches.GTS.Block().Invalidate("ID", block.ID)
return nil
}
func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error {
var blockIDs []string
if err := r.conn.NewSelect().
Table("blocks").
ColumnExpr("?", bun.Ident("id")).
WhereOr("? = ? OR ? = ?",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
accountID,
).
Scan(ctx, &blockIDs); err != nil {
return r.conn.ProcessError(err)
}
for _, id := range blockIDs {
if err := r.DeleteBlockByID(ctx, id); err != nil {
log.Errorf(ctx, "error deleting block %q: %v", id, err)
}
}
return nil
}

View File

@ -0,0 +1,243 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
"ID",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
},
id,
)
}
func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
"URI",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("uri"), uri).
Scan(ctx)
},
uri,
)
}
func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
"AccountID.TargetAccountID",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
Scan(ctx)
},
sourceAccountID,
targetAccountID,
)
}
func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
// Preallocate slice of expected length.
follows := make([]*gtsmodel.Follow, 0, len(ids))
for _, id := range ids {
// Fetch follow model for this ID.
follow, err := r.GetFollowByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow %q: %v", id, err)
continue
}
// Append to return slice.
follows = append(follows, follow)
}
return follows, nil
}
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
follow, err := r.GetFollow(
gtscontext.SetBarebones(ctx),
sourceAccountID,
targetAccountID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
return (follow != nil), nil
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) {
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx,
accountID1,
accountID2,
)
if !f1 /* f1 = false when err != nil */ {
return false, err
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx,
accountID2,
accountID1,
)
if !f2 /* f2 = false when err != nil */ {
return false, err
}
return true, nil
}
func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
// Fetch follow from database cache with loader callback
follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
var follow gtsmodel.Follow
// Not cached! Perform database query
if err := dbQuery(&follow); err != nil {
return nil, r.conn.ProcessError(err)
}
return &follow, nil
}, keyParts...)
if err != nil {
// error already processed
return nil, err
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return follow, nil
}
// Set the follow source account
follow.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow source account: %w", err)
}
// Set the follow target account
follow.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow target account: %w", err)
}
return follow, nil
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
err := r.state.Caches.GTS.Follow().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
return r.conn.ProcessError(err)
})
if err != nil {
return err
}
// Invalidate follow origin account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow target account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
return nil
}
func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
if _, err := r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate follow from cache lookups.
r.state.Caches.GTS.Follow().Invalidate("ID", id)
return nil
}
func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
if _, err := r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate follow from cache lookups.
r.state.Caches.GTS.Follow().Invalidate("URI", uri)
return nil
}
func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
var followIDs []string
if _, err := r.conn.
NewDelete().
Table("follows").
WhereOr("? = ? OR ? = ?",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
accountID,
).
Returning("?", bun.Ident("id")).
Exec(ctx, &followIDs); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate each returned ID.
for _, id := range followIDs {
r.state.Caches.GTS.Follow().Invalidate("ID", id)
}
return nil
}

View File

@ -0,0 +1,293 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest(
ctx,
"ID",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
},
id,
)
}
func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest(
ctx,
"URI",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("uri"), uri).
Scan(ctx)
},
uri,
)
}
func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest(
ctx,
"AccountID.TargetAccountID",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
Scan(ctx)
},
sourceAccountID,
targetAccountID,
)
}
func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
// Preallocate slice of expected length.
followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids))
for _, id := range ids {
// Fetch follow request model for this ID.
followReq, err := r.GetFollowRequestByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow request %q: %v", id, err)
continue
}
// Append to return slice.
followReqs = append(followReqs, followReq)
}
return followReqs, nil
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
followReq, err := r.GetFollowRequest(
gtscontext.SetBarebones(ctx),
sourceAccountID,
targetAccountID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
return (followReq != nil), nil
}
func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
// Fetch follow request from database cache with loader callback
followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) {
var followReq gtsmodel.FollowRequest
// Not cached! Perform database query
if err := dbQuery(&followReq); err != nil {
return nil, r.conn.ProcessError(err)
}
return &followReq, nil
}, keyParts...)
if err != nil {
// error already processed
return nil, err
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return followReq, nil
}
// Set the follow request source account
followReq.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
followReq.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow request source account: %w", err)
}
// Set the follow request target account
followReq.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
followReq.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow request target account: %w", err)
}
return followReq, nil
}
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
err := r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
return r.conn.ProcessError(err)
})
if err != nil {
return err
}
// Invalidate follow request origin account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow request target account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
return nil
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// Get original follow request.
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
if err != nil {
return nil, err
}
// Create a new follow to 'replace'
// the original follow request with.
follow := &gtsmodel.Follow{
ID: followReq.ID,
AccountID: sourceAccountID,
Account: followReq.Account,
TargetAccountID: targetAccountID,
TargetAccount: followReq.TargetAccount,
URI: followReq.URI,
}
// If the follow already exists, just
// replace the URI with the new one.
if _, err := r.conn.
NewInsert().
Model(follow).
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
// Delete original follow request.
if _, err := r.conn.
NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), followReq.ID).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
// Invalidate follow request from cache lookups.
r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
// Delete original follow request notification
if err := r.state.DB.DeleteNotifications(ctx, []string{
string(gtsmodel.NotificationFollowRequest),
}, targetAccountID, sourceAccountID); err != nil {
return nil, err
}
return follow, nil
}
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error {
// Get original follow request.
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
if err != nil {
return err
}
// Delete original follow request.
if _, err := r.conn.
NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), followReq.ID).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Delete original follow request notification
return r.state.DB.DeleteNotifications(ctx, []string{
string(gtsmodel.NotificationFollowRequest),
}, targetAccountID, sourceAccountID)
}
func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
if _, err := r.conn.NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate follow request from cache lookups.
r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
return nil
}
func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
if _, err := r.conn.NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate follow request from cache lookups.
r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
return nil
}
func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
var followIDs []string
if _, err := r.conn.
NewDelete().
Table("follow_requests").
WhereOr("? = ? OR ? = ?",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
accountID,
).
Returning("?", bun.Ident("id")).
Exec(ctx, &followIDs); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate each returned ID.
for _, id := range followIDs {
r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
}
return nil
}

View File

@ -19,17 +19,359 @@ package bundb_test
import (
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
type RelationshipTestSuite struct {
BunDBStandardTestSuite
}
func (suite *RelationshipTestSuite) TestGetBlockBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 block models are equal.
isEqual := func(b1, b2 gtsmodel.Block) bool {
// Clear populated sub-models.
b1.Account = nil
b2.Account = nil
b1.TargetAccount = nil
b2.TargetAccount = nil
// Clear database-set fields.
b1.CreatedAt = time.Time{}
b2.CreatedAt = time.Time{}
b1.UpdatedAt = time.Time{}
b2.UpdatedAt = time.Time{}
return reflect.DeepEqual(b1, b2)
}
var testBlocks []*gtsmodel.Block
for _, account1 := range suite.testAccounts {
for _, account2 := range suite.testAccounts {
if account1.ID == account2.ID {
// don't block *yourself* ...
continue
}
// Create new account block.
block := &gtsmodel.Block{
ID: id.NewULID(),
URI: "http://127.0.0.1:8080/" + id.NewULID(),
AccountID: account1.ID,
TargetAccountID: account2.ID,
}
// Attempt to place the block in database (if not already).
if err := suite.db.PutBlock(ctx, block); err != nil {
if err != db.ErrAlreadyExists {
// Unrecoverable database error.
t.Fatalf("error creating block: %v", err)
}
// Fetch existing block from database between accounts.
block, _ = suite.db.GetBlock(ctx, account1.ID, account2.ID)
continue
}
// Append generated block to test cases.
testBlocks = append(testBlocks, block)
}
}
for _, block := range testBlocks {
for lookup, dbfunc := range map[string]func() (*gtsmodel.Block, error){
"id": func() (*gtsmodel.Block, error) {
return suite.db.GetBlockByID(ctx, block.ID)
},
"uri": func() (*gtsmodel.Block, error) {
return suite.db.GetBlockByURI(ctx, block.URI)
},
"origin_target": func() (*gtsmodel.Block, error) {
return suite.db.GetBlock(ctx, block.AccountID, block.TargetAccountID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkBlock, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received block data.
if !isEqual(*checkBlock, *block) {
t.Errorf("block does not contain expected data: %+v", checkBlock)
continue
}
// Check that block origin account populated.
if checkBlock.Account == nil || checkBlock.Account.ID != block.AccountID {
t.Errorf("block origin account not correctly populated for: %+v", checkBlock)
continue
}
// Check that block target account populated.
if checkBlock.TargetAccount == nil || checkBlock.TargetAccount.ID != block.TargetAccountID {
t.Errorf("block target account not correctly populated for: %+v", checkBlock)
continue
}
}
}
}
func (suite *RelationshipTestSuite) TestGetFollowBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 follow models are equal.
isEqual := func(f1, f2 gtsmodel.Follow) bool {
// Clear populated sub-models.
f1.Account = nil
f2.Account = nil
f1.TargetAccount = nil
f2.TargetAccount = nil
// Clear database-set fields.
f1.CreatedAt = time.Time{}
f2.CreatedAt = time.Time{}
f1.UpdatedAt = time.Time{}
f2.UpdatedAt = time.Time{}
return reflect.DeepEqual(f1, f2)
}
var testFollows []*gtsmodel.Follow
for _, account1 := range suite.testAccounts {
for _, account2 := range suite.testAccounts {
if account1.ID == account2.ID {
// don't follow *yourself* ...
continue
}
// Create new account follow.
follow := &gtsmodel.Follow{
ID: id.NewULID(),
URI: "http://127.0.0.1:8080/" + id.NewULID(),
AccountID: account1.ID,
TargetAccountID: account2.ID,
}
// Attempt to place the follow in database (if not already).
if err := suite.db.PutFollow(ctx, follow); err != nil {
if err != db.ErrAlreadyExists {
// Unrecoverable database error.
t.Fatalf("error creating follow: %v", err)
}
// Fetch existing follow from database between accounts.
follow, _ = suite.db.GetFollow(ctx, account1.ID, account2.ID)
continue
}
// Append generated follow to test cases.
testFollows = append(testFollows, follow)
}
}
for _, follow := range testFollows {
for lookup, dbfunc := range map[string]func() (*gtsmodel.Follow, error){
"id": func() (*gtsmodel.Follow, error) {
return suite.db.GetFollowByID(ctx, follow.ID)
},
"uri": func() (*gtsmodel.Follow, error) {
return suite.db.GetFollowByURI(ctx, follow.URI)
},
"origin_target": func() (*gtsmodel.Follow, error) {
return suite.db.GetFollow(ctx, follow.AccountID, follow.TargetAccountID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkFollow, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received follow data.
if !isEqual(*checkFollow, *follow) {
t.Errorf("follow does not contain expected data: %+v", checkFollow)
continue
}
// Check that follow origin account populated.
if checkFollow.Account == nil || checkFollow.Account.ID != follow.AccountID {
t.Errorf("follow origin account not correctly populated for: %+v", checkFollow)
continue
}
// Check that follow target account populated.
if checkFollow.TargetAccount == nil || checkFollow.TargetAccount.ID != follow.TargetAccountID {
t.Errorf("follow target account not correctly populated for: %+v", checkFollow)
continue
}
}
}
}
func (suite *RelationshipTestSuite) TestGetFollowRequestBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 follow request models are equal.
isEqual := func(f1, f2 gtsmodel.FollowRequest) bool {
// Clear populated sub-models.
f1.Account = nil
f2.Account = nil
f1.TargetAccount = nil
f2.TargetAccount = nil
// Clear database-set fields.
f1.CreatedAt = time.Time{}
f2.CreatedAt = time.Time{}
f1.UpdatedAt = time.Time{}
f2.UpdatedAt = time.Time{}
return reflect.DeepEqual(f1, f2)
}
var testFollowReqs []*gtsmodel.FollowRequest
for _, account1 := range suite.testAccounts {
for _, account2 := range suite.testAccounts {
if account1.ID == account2.ID {
// don't follow *yourself* ...
continue
}
// Create new account follow request.
followReq := &gtsmodel.FollowRequest{
ID: id.NewULID(),
URI: "http://127.0.0.1:8080/" + id.NewULID(),
AccountID: account1.ID,
TargetAccountID: account2.ID,
}
// Attempt to place the follow in database (if not already).
if err := suite.db.PutFollowRequest(ctx, followReq); err != nil {
if err != db.ErrAlreadyExists {
// Unrecoverable database error.
t.Fatalf("error creating follow request: %v", err)
}
// Fetch existing follow request from database between accounts.
followReq, _ = suite.db.GetFollowRequest(ctx, account1.ID, account2.ID)
continue
}
// Append generated follow request to test cases.
testFollowReqs = append(testFollowReqs, followReq)
}
}
for _, followReq := range testFollowReqs {
for lookup, dbfunc := range map[string]func() (*gtsmodel.FollowRequest, error){
"id": func() (*gtsmodel.FollowRequest, error) {
return suite.db.GetFollowRequestByID(ctx, followReq.ID)
},
"uri": func() (*gtsmodel.FollowRequest, error) {
return suite.db.GetFollowRequestByURI(ctx, followReq.URI)
},
"origin_target": func() (*gtsmodel.FollowRequest, error) {
return suite.db.GetFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkFollowReq, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received follow request data.
if !isEqual(*checkFollowReq, *followReq) {
t.Errorf("follow request does not contain expected data: %+v", checkFollowReq)
continue
}
// Check that follow request origin account populated.
if checkFollowReq.Account == nil || checkFollowReq.Account.ID != followReq.AccountID {
t.Errorf("follow request origin account not correctly populated for: %+v", checkFollowReq)
continue
}
// Check that follow request target account populated.
if checkFollowReq.TargetAccount == nil || checkFollowReq.TargetAccount.ID != followReq.TargetAccountID {
t.Errorf("follow request target account not correctly populated for: %+v", checkFollowReq)
continue
}
}
}
}
func (suite *RelationshipTestSuite) TestIsBlocked() {
ctx := context.Background()
@ -37,11 +379,11 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
account2 := suite.testAccounts["local_account_2"].ID
// no blocks exist between account 1 and account 2
blocked, err := suite.db.IsBlocked(ctx, account1, account2, false)
blocked, err := suite.db.IsBlocked(ctx, account1, account2)
suite.NoError(err)
suite.False(blocked)
blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
blocked, err = suite.db.IsBlocked(ctx, account2, account1)
suite.NoError(err)
suite.False(blocked)
@ -56,45 +398,24 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
}
// account 1 now blocks account 2
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
blocked, err = suite.db.IsBlocked(ctx, account1, account2)
suite.NoError(err)
suite.True(blocked)
// account 2 doesn't block account 1
blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
blocked, err = suite.db.IsBlocked(ctx, account2, account1)
suite.NoError(err)
suite.False(blocked)
// a block exists in either direction between the two
blocked, err = suite.db.IsBlocked(ctx, account1, account2, true)
blocked, err = suite.db.IsEitherBlocked(ctx, account1, account2)
suite.NoError(err)
suite.True(blocked)
blocked, err = suite.db.IsBlocked(ctx, account2, account1, true)
blocked, err = suite.db.IsEitherBlocked(ctx, account2, account1)
suite.NoError(err)
suite.True(blocked)
}
func (suite *RelationshipTestSuite) TestGetBlock() {
ctx := context.Background()
account1 := suite.testAccounts["local_account_1"].ID
account2 := suite.testAccounts["local_account_2"].ID
if err := suite.db.PutBlock(ctx, &gtsmodel.Block{
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
URI: "http://localhost:8080/some_block_uri_1",
AccountID: account1,
TargetAccountID: account2,
}); err != nil {
suite.FailNow(err.Error())
}
block, err := suite.db.GetBlock(ctx, account1, account2)
suite.NoError(err)
suite.NotNil(block)
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
}
func (suite *RelationshipTestSuite) TestDeleteBlockByID() {
ctx := context.Background()
@ -157,7 +478,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlockByURI() {
suite.Nil(block)
}
func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
func (suite *RelationshipTestSuite) TestDeleteAccountBlocks() {
ctx := context.Background()
// put a block in first
@ -179,38 +500,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
// delete the block by originAccountID
err = suite.db.DeleteBlocksByOriginAccountID(ctx, account1)
suite.NoError(err)
// block should be gone
block, err = suite.db.GetBlock(ctx, account1, account2)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(block)
}
func (suite *RelationshipTestSuite) TestDeleteBlocksByTargetAccountID() {
ctx := context.Background()
// put a block in first
account1 := suite.testAccounts["local_account_1"].ID
account2 := suite.testAccounts["local_account_2"].ID
if err := suite.db.PutBlock(ctx, &gtsmodel.Block{
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
URI: "http://localhost:8080/some_block_uri_1",
AccountID: account1,
TargetAccountID: account2,
}); err != nil {
suite.FailNow(err.Error())
}
// make sure the block is in the db
block, err := suite.db.GetBlock(ctx, account1, account2)
suite.NoError(err)
suite.NotNil(block)
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
// delete the block by targetAccountID
err = suite.db.DeleteBlocksByTargetAccountID(ctx, account2)
err = suite.db.DeleteAccountBlocks(ctx, account1)
suite.NoError(err)
// block should be gone
@ -244,7 +534,7 @@ func (suite *RelationshipTestSuite) TestGetRelationship() {
func (suite *RelationshipTestSuite) TestIsFollowingYes() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isFollowing)
}
@ -252,7 +542,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingYes() {
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
requestingAccount := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.False(isFollowing)
}
@ -260,7 +550,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingNo() {
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isMutualFollowing)
}
@ -268,7 +558,7 @@ func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_2"]
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isMutualFollowing)
}
@ -306,7 +596,7 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
suite.Equal(followRequest.URI, follow.URI)
// Ensure notification is deleted.
notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(notification)
}
@ -389,7 +679,7 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
TargetAccountID: targetAccount.ID,
}
if err := suite.db.Put(ctx, followRequest); err != nil {
if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
@ -404,12 +694,11 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
suite.FailNow(err.Error())
}
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.NoError(err)
suite.NotNil(rejectedFollowRequest)
// Ensure notification is deleted.
notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(notification)
}
@ -419,9 +708,8 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(rejectedFollowRequest)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
@ -440,42 +728,49 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
suite.FailNow(err.Error())
}
followRequests, err := suite.db.GetFollowRequests(ctx, "", targetAccount.ID)
followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
suite.NoError(err)
suite.Len(followRequests, 1)
}
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetFollows(context.Background(), account.ID, "")
follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountFollows(context.Background(), account.ID, "")
followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetFollows(context.Background(), "", account.ID)
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
func (suite *RelationshipTestSuite) TestGetLocalFollowersIDs() {
func (suite *RelationshipTestSuite) TestCountAccountFollowers() {
account := suite.testAccounts["local_account_1"]
accountIDs, err := suite.db.GetLocalFollowersIDs(context.Background(), account.ID)
followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.EqualValues([]string{"01F8MH5NBDF2MV7CTC4Q5128HF", "01F8MH17FWEB39HZJ76B6VXSKF"}, accountIDs)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountFollows(context.Background(), "", account.ID)
followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
@ -484,18 +779,25 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccount.ID)
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.Equal("http://localhost:8080/users/the_mighty_zork/follow/01F8PY8RHWRQZV038T4E8T9YK8", uri)
suite.NotNil(follow)
err = suite.db.DeleteFollowByID(context.Background(), follow.ID)
suite.NoError(err)
follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(follow)
}
func (suite *RelationshipTestSuite) TestUnfollowNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccountID)
suite.NoError(err)
suite.Empty(uri)
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccountID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(follow)
}
func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
@ -510,22 +812,29 @@ func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
TargetAccountID: targetAccount.ID,
}
if err := suite.db.Put(ctx, followRequest); err != nil {
if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.Equal("http://localhost:8080/weeeeeeeeeeeeeeeee", uri)
suite.NotNil(followRequest)
err = suite.db.DeleteFollowRequestByID(context.Background(), followRequest.ID)
suite.NoError(err)
followRequest, err = suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(followRequest)
}
func (suite *RelationshipTestSuite) TestUnfollowRequestNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccountID)
suite.NoError(err)
suite.Empty(uri)
followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccountID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(followRequest)
}
func TestRelationshipTestSuite(t *testing.T) {

View File

@ -26,6 +26,7 @@ import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -41,7 +42,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(status).
Relation("Attachments").
Relation("Tags").
Relation("CreatedWithApplication")
}
@ -102,81 +102,143 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status
// Not cached! Perform database query
// Not cached! Perform database query.
if err := dbQuery(&status); err != nil {
return nil, s.conn.ProcessError(err)
}
if status.InReplyToID != "" {
// Also load in-reply-to status
status.InReplyTo = new(gtsmodel.Status)
err := s.conn.NewSelect().Model(status.InReplyTo).
Where("? = ?", bun.Ident("status.id"), status.InReplyToID).
Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
}
if status.BoostOfID != "" {
// Also load original boosted status
status.BoostOf = new(gtsmodel.Status)
err := s.conn.NewSelect().Model(status.BoostOf).
Where("? = ?", bun.Ident("status.id"), status.BoostOfID).
Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
}
return &status, nil
}, keyParts...)
if err != nil {
// error already processed
return nil, err
}
// Set the status author account
status.Account, err = s.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil {
return nil, fmt.Errorf("error getting status account: %w", err)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return status, nil
}
if id := status.BoostOfAccountID; id != "" {
// Set boost of status' author account
status.BoostOfAccount, err = s.state.DB.GetAccountByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("error getting boosted status account: %w", err)
}
}
if id := status.InReplyToAccountID; id != "" {
// Set in-reply-to status' author account
status.InReplyToAccount, err = s.state.DB.GetAccountByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("error getting in reply to status account: %w", err)
}
}
if len(status.EmojiIDs) > 0 {
// Fetch status emojis
status.Emojis, err = s.state.DB.GetEmojisByIDs(ctx, status.EmojiIDs)
if err != nil {
return nil, fmt.Errorf("error getting status emojis: %w", err)
}
}
if len(status.MentionIDs) > 0 {
// Fetch status mentions
status.Mentions, err = s.state.DB.GetMentions(ctx, status.MentionIDs)
if err != nil {
return nil, fmt.Errorf("error getting status mentions: %w", err)
}
// Further populate the status fields where applicable.
if err := s.PopulateStatus(ctx, status); err != nil {
return nil, err
}
return status, nil
}
func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) error {
var err error
if status.Account == nil {
// Status author is not set, fetch from database.
status.Account, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
status.AccountID,
)
if err != nil {
return fmt.Errorf("error populating status author: %w", err)
}
}
if status.InReplyToID != "" && status.InReplyTo == nil {
// Status parent is not set, fetch from database.
status.InReplyTo, err = s.GetStatusByID(
gtscontext.SetBarebones(ctx),
status.InReplyToID,
)
if err != nil {
return fmt.Errorf("error populating status parent: %w", err)
}
}
if status.InReplyToID != "" {
if status.InReplyTo == nil {
// Status parent is not set, fetch from database.
status.InReplyTo, err = s.GetStatusByID(
gtscontext.SetBarebones(ctx),
status.InReplyToID,
)
if err != nil {
return fmt.Errorf("error populating status parent: %w", err)
}
}
if status.InReplyToAccount == nil {
// Status parent author is not set, fetch from database.
status.InReplyToAccount, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
status.InReplyToAccountID,
)
if err != nil {
return fmt.Errorf("error populating status parent author: %w", err)
}
}
}
if status.BoostOfID != "" {
if status.BoostOf == nil {
// Status boost is not set, fetch from database.
status.BoostOf, err = s.GetStatusByID(
gtscontext.SetBarebones(ctx),
status.BoostOfID,
)
if err != nil {
return fmt.Errorf("error populating status boost: %w", err)
}
}
if status.BoostOfAccount == nil {
// Status boost author is not set, fetch from database.
status.BoostOfAccount, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
status.BoostOfAccountID,
)
if err != nil {
return fmt.Errorf("error populating status boost author: %w", err)
}
}
}
if !status.AttachmentsPopulated() {
// Status attachments are out-of-date with IDs, repopulate.
status.Attachments, err = s.state.DB.GetAttachmentsByIDs(
ctx, // these are already barebones
status.AttachmentIDs,
)
if err != nil {
return fmt.Errorf("error populating status attachments: %w", err)
}
}
// TODO: once we don't fetch using relations.
// if !status.TagsPopulated() {
// }
if !status.MentionsPopulated() {
// Status mentions are out-of-date with IDs, repopulate.
status.Mentions, err = s.state.DB.GetMentions(
ctx, // leave fully populated for now
status.MentionIDs,
)
if err != nil {
return fmt.Errorf("error populating status mentions: %w", err)
}
}
if !status.EmojisPopulated() {
// Status emojis are out-of-date with IDs, repopulate.
status.Emojis, err = s.state.DB.GetEmojisByIDs(
ctx, // these are already barebones
status.EmojiIDs,
)
if err != nil {
return fmt.Errorf("error populating status emojis: %w", err)
}
}
return nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
err := s.state.Caches.GTS.Status().Store(status, func() error {
// It is safe to run this database transaction within cache.Store
@ -239,12 +301,16 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
})
})
if err != nil {
// already processed
return err
}
for _, id := range status.AttachmentIDs {
// Clear updated media attachment IDs from cache
// Invalidate media attachments from cache.
//
// NOTE: this is needed due to the way in which
// we upload status attachments, and only after
// update them with a known status ID. This is
// not the case for header/avatar attachments.
s.state.Caches.GTS.Media().Invalidate("ID", id)
}
@ -322,14 +388,19 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
return err
}
// Invalidate status from database lookups.
s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
for _, id := range status.AttachmentIDs {
// Clear updated media attachment IDs from cache
// Invalidate media attachments from cache.
//
// NOTE: this is needed due to the way in which
// we upload status attachments, and only after
// update them with a known status ID. This is
// not the case for header/avatar attachments.
s.state.Caches.GTS.Media().Invalidate("ID", id)
}
// Drop any old status value from cache by this ID
s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
return nil
}
@ -367,8 +438,12 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
return err
}
// Drop any old value from cache by this ID
// Invalidate status from database lookups.
s.state.Caches.GTS.Status().Invalidate("ID", id)
// Invalidate status from all visibility lookups.
s.state.Caches.Visibility.Invalidate("ItemID", id)
return nil
}

View File

@ -23,6 +23,7 @@ import (
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -34,29 +35,82 @@ type statusFaveDB struct {
state *state.State
}
func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
fave := new(gtsmodel.StatusFave)
err := s.conn.
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
return s.getStatusFave(
ctx,
"AccountID.StatusID",
func(fave *gtsmodel.StatusFave) error {
return s.conn.
NewSelect().
Model(fave).
Where("? = ?", bun.Ident("status_fave.ID"), id).
Where("? = ?", bun.Ident("account_id"), accountID).
Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx)
if err != nil {
},
accountID,
statusID,
)
}
func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
return s.getStatusFave(
ctx,
"ID",
func(fave *gtsmodel.StatusFave) error {
return s.conn.
NewSelect().
Model(fave).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
},
id,
)
}
func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {
// Fetch status fave from database cache with loader callback
fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) {
var fave gtsmodel.StatusFave
// Not cached! Perform database query.
if err := dbQuery(&fave); err != nil {
return nil, s.conn.ProcessError(err)
}
fave.Account, err = s.state.DB.GetAccountByID(ctx, fave.AccountID)
return &fave, nil
}, keyParts...)
if err != nil {
return nil, err
}
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return fave, nil
}
// Fetch the status fave author account.
fave.Account, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
fave.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err)
}
fave.TargetAccount, err = s.state.DB.GetAccountByID(ctx, fave.TargetAccountID)
// Fetch the status fave target account.
fave.TargetAccount, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
fave.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err)
}
fave.Status, err = s.state.DB.GetStatusByID(ctx, fave.StatusID)
// Fetch the status fave target status.
fave.Status, err = s.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
fave.StatusID,
)
if err != nil {
return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err)
}
@ -64,38 +118,22 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.
return fave, nil
}
func (s *statusFaveDB) GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
var id string
err := s.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Column("status_fave.id").
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
Scan(ctx, &id)
if err != nil {
return nil, s.conn.ProcessError(err)
}
return s.GetStatusFave(ctx, id)
}
func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
ids := []string{}
if err := s.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Column("status_fave.id").
Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
Table("status_faves").
Column("id").
Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx, &ids); err != nil {
return nil, s.conn.ProcessError(err)
}
faves := make([]*gtsmodel.StatusFave, 0, len(ids))
for _, id := range ids {
fave, err := s.GetStatusFave(ctx, id)
fave, err := s.GetStatusFaveByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status fave %q: %v", id, err)
continue
@ -107,23 +145,27 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*
return faves, nil
}
func (s *statusFaveDB) PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) db.Error {
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error {
return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
_, err := s.conn.
NewInsert().
Model(statusFave).
Model(fave).
Exec(ctx)
return s.conn.ProcessError(err)
})
}
func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error {
if _, err := s.conn.
NewDelete().
Table("status_faves").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return s.conn.ProcessError(err)
}
func (s *statusFaveDB) DeleteStatusFave(ctx context.Context, id string) db.Error {
_, err := s.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Where("? = ?", bun.Ident("status_fave.id"), id).
Exec(ctx)
return s.conn.ProcessError(err)
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
return nil
}
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
@ -131,42 +173,52 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set")
}
// TODO: Capture fave IDs in a RETURNING
// statement (when faves have a cache),
// + use the IDs to invalidate cache entries.
// Capture fave IDs in a RETURNING statement.
var faveIDs []string
q := s.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave"))
Table("status_faves").
Returning("?", bun.Ident("id"))
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("status_fave.target_account_id"), targetAccountID)
q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID)
}
if originAccountID != "" {
q = q.Where("? = ?", bun.Ident("status_fave.account_id"), originAccountID)
q = q.Where("? = ?", bun.Ident("account_id"), originAccountID)
}
if _, err := q.Exec(ctx); err != nil {
if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
}
for _, id := range faveIDs {
// Invalidate each of the returned status fave IDs.
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
}
return nil
}
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error {
// TODO: Capture fave IDs in a RETURNING
// statement (when faves have a cache),
// + use the IDs to invalidate cache entries.
// Capture fave IDs in a RETURNING statement.
var faveIDs []string
q := s.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Where("? = ?", bun.Ident("status_fave.status_id"), statusID)
Table("status_faves").
Where("? = ?", bun.Ident("status_id"), statusID).
Returning("?", bun.Ident("id"))
if _, err := q.Exec(ctx); err != nil {
if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
}
for _, id := range faveIDs {
// Invalidate each of the returned status fave IDs.
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
}
return nil
}

View File

@ -35,7 +35,7 @@ type StatusFaveTestSuite struct {
func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
testStatus := suite.testStatuses["admin_account_status_1"]
faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -51,7 +51,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
func (suite *StatusFaveTestSuite) TestGetStatusFavesNone() {
testStatus := suite.testStatuses["admin_account_status_4"]
faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -63,7 +63,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaveByAccountID() {
testAccount := suite.testAccounts["local_account_1"]
testStatus := suite.testStatuses["admin_account_status_1"]
fave, err := suite.db.GetStatusFaveByAccountID(context.Background(), testAccount.ID, testStatus.ID)
fave, err := suite.db.GetStatusFave(context.Background(), testAccount.ID, testStatus.ID)
suite.NoError(err)
suite.NotNil(fave)
}
@ -129,17 +129,17 @@ func (suite *StatusFaveTestSuite) TestDeleteStatusFave() {
testFave := suite.testFaves["local_account_1_admin_account_status_1"]
ctx := context.Background()
if err := suite.db.DeleteStatusFave(ctx, testFave.ID); err != nil {
if err := suite.db.DeleteStatusFaveByID(ctx, testFave.ID); err != nil {
suite.FailNow(err.Error())
}
fave, err := suite.db.GetStatusFave(ctx, testFave.ID)
fave, err := suite.db.GetStatusFaveByID(ctx, testFave.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(fave)
}
func (suite *StatusFaveTestSuite) TestDeleteStatusFaveNonExisting() {
err := suite.db.DeleteStatusFave(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
err := suite.db.DeleteStatusFaveByID(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
suite.NoError(err)
}

View File

@ -61,9 +61,12 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
Order("status.id DESC")
if maxID == "" {
const future = 24 * time.Hour
var err error
// don't return statuses more than five minutes in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
// don't return statuses more than 24hr in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}
@ -138,15 +141,16 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id").
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")).
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
Order("status.id DESC")
if maxID == "" {
const future = 24 * time.Hour
var err error
// don't return statuses more than five minutes in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
// don't return statuses more than 24hr in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}

View File

@ -34,15 +34,32 @@ type TimelineTestSuite struct {
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
ctx := context.Background()
var count int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
count++
}
}
ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
suite.Len(s, 6)
suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
var count int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
count++
}
}
ctx := context.Background()
futureStatus := getFutureStatus()
@ -53,7 +70,7 @@ func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
suite.NoError(err)
suite.NotContains(s, futureStatus)
suite.Len(s, 6)
suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {

View File

@ -29,6 +29,9 @@ type Media interface {
// GetAttachmentByID gets a single attachment by its ID.
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
// GetAttachmentsByIDs fetches a list of media attachments for given IDs.
GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error)
// PutAttachment inserts the given attachment into the database.
PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error

View File

@ -30,4 +30,10 @@ type Mention interface {
// GetMentions gets multiple mentions.
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
// PutMention will insert the given mention into the database.
PutMention(ctx context.Context, mention *gtsmodel.Mention) error
// DeleteMentionByID will delete mention with given ID from the database.
DeleteMentionByID(ctx context.Context, id string) error
}

View File

@ -28,14 +28,17 @@ type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error)
// DeleteNotification deletes one notification according to its id,
// PutNotification will insert the given notification into the database.
PutNotification(ctx context.Context, notif *gtsmodel.Notification) error
// DeleteNotificationByID deletes one notification according to its id,
// and removes that notification from the in-memory cache.
DeleteNotification(ctx context.Context, id string) Error
DeleteNotificationByID(ctx context.Context, id string) Error
// DeleteNotifications mass deletes notifications targeting targetAccountID
// and/or originating from originAccountID.
@ -50,7 +53,7 @@ type Notification interface {
// originate from originAccountID will be deleted.
//
// At least one parameter must not be an empty string.
DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) Error
DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error
// DeleteNotificationsForStatus deletes all notifications that relate to
// the given statusID. This function is useful when a status has been deleted,

View File

@ -25,42 +25,86 @@ import (
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against account2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
// IsBlocked checks whether source account has a block in place against target.
IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// IsEitherBlocked checks whether there is a block in place between either of account1 and account2.
IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error)
// GetBlockByID fetches block with given ID from the database.
GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error)
// GetBlockByURI fetches block with given AP URI from the database.
GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, error)
// PutBlock attempts to place the given account block in the database.
PutBlock(ctx context.Context, block *gtsmodel.Block) Error
PutBlock(ctx context.Context, block *gtsmodel.Block) error
// DeleteBlockByID removes block with given ID from the database.
DeleteBlockByID(ctx context.Context, id string) Error
DeleteBlockByID(ctx context.Context, id string) error
// DeleteBlockByURI removes block with given AP URI from the database.
DeleteBlockByURI(ctx context.Context, uri string) Error
DeleteBlockByURI(ctx context.Context, uri string) error
// DeleteBlocksByOriginAccountID removes any blocks with accountID equal to originAccountID.
DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) Error
// DeleteBlocksByTargetAccountID removes any blocks with given targetAccountID.
DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) Error
// DeleteAccountBlocks will delete all database blocks to / from the given account ID.
DeleteAccountBlocks(ctx context.Context, accountID string) error
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// GetFollowByID fetches follow with given ID from the database.
GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// GetFollowByURI fetches follow with given AP URI from the database.
GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error)
// GetFollow retrieves a follow if it exists between source and target accounts.
GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// GetFollowRequestByID fetches follow request with given ID from the database.
GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error)
// GetFollowRequestByURI fetches follow request with given AP URI from the database.
GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error)
// GetFollowRequest retrieves a follow request if it exists between source and target accounts.
GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// PutFollow attempts to place the given account follow in the database.
PutFollow(ctx context.Context, follow *gtsmodel.Follow) error
// PutFollowRequest attempts to place the given account follow request in the database.
PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error
// DeleteFollowByID deletes a follow from the database with the given ID.
DeleteFollowByID(ctx context.Context, id string) error
// DeleteFollowByURI deletes a follow from the database with the given URI.
DeleteFollowByURI(ctx context.Context, uri string) error
// DeleteFollowRequestByID deletes a follow request from the database with the given ID.
DeleteFollowRequestByID(ctx context.Context, id string) error
// DeleteFollowRequestByURI deletes a follow request from the database with the given URI.
DeleteFollowRequestByURI(ctx context.Context, uri string) error
// DeleteAccountFollows will delete all database follows to / from the given account ID.
DeleteAccountFollows(ctx context.Context, accountID string) error
// DeleteAccountFollowRequests will delete all database follow requests to / from the given account ID.
DeleteAccountFollowRequests(ctx context.Context, accountID string) error
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
@ -69,65 +113,41 @@ type Relationship interface {
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// RejectFollowRequest fetches a follow request from the database, and then deletes it.
//
// The deleted follow request will be returned so that further processing can be done on it.
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, Error)
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error
// GetFollows returns a slice of follows owned by the given accountID, and/or
// targeting the given account id.
//
// If accountID is set and targetAccountID isn't, then all follows created by
// accountID will be returned.
//
// If targetAccountID is set and accountID isn't, then all follows targeting
// targetAccountID will be returned.
//
// If both accountID and targetAccountID are set, then only 0 or 1 follows will
// be in the returned slice.
GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// GetLocalFollowersIDs returns a list of local account IDs which follow the
// targetAccountID. The returned IDs are not guaranteed to be ordered in any
// particular way, so take care.
GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, Error)
// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// CountFollows is like GetFollows, but just counts rather than returning.
CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
CountAccountFollows(ctx context.Context, accountID string) (int, error)
// GetFollowRequests returns a slice of follows requests owned by the given
// accountID, and/or targeting the given account id.
//
// If accountID is set and targetAccountID isn't, then all requests created by
// accountID will be returned.
//
// If targetAccountID is set and accountID isn't, then all requests targeting
// targetAccountID will be returned.
//
// If both accountID and targetAccountID are set, then only 0 or 1 requests will
// be in the returned slice.
GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, Error)
// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
// CountFollowRequests is like GetFollowRequests, but just counts rather than returning.
CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, Error)
// GetAccountFollowers fetches follows that target given accountID.
GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// Unfollow removes a follow targeting targetAccountID and originating
// from originAccountID.
//
// If a follow was removed this way, the AP URI of the follow will be
// returned to the caller, so that further processing can take place
// if necessary.
//
// If no follow was removed this way, the returned string will be empty.
Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// UnfollowRequest removes a follow request targeting targetAccountID
// and originating from originAccountID.
//
// If a follow request was removed this way, the AP URI of the follow
// request will be returned to the caller, so that further processing
// can take place if necessary.
//
// If no follow request was removed this way, the returned string will
// be empty.
UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
// CountAccountFollowers returns the amounts that the given ID is followed by.
CountAccountFollowers(ctx context.Context, accountID string) (int, error)
// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
// GetAccountFollowRequesting returns all follow requests originating from the given account.
GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
// CountAccountFollowRequests returns number of follow requests targeting the given account.
CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
// CountAccountFollowerRequests returns number of follow requests originating from the given account.
CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
}

View File

@ -37,6 +37,9 @@ type Status interface {
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc).
PopulateStatus(ctx context.Context, status *gtsmodel.Status) error
// PutStatus stores one status in the database.
PutStatus(ctx context.Context, status *gtsmodel.Status) Error

View File

@ -24,22 +24,22 @@ import (
)
type StatusFave interface {
// GetStatusFave returns one status fave with the given id.
GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
// GetStatusFaveByAccountID gets one status fave created by the given
// accountID, targeting the given statusID.
GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
// GetStatusFave returns one status fave with the given id.
GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
// PutStatusFave inserts the given statusFave into the database.
PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error
// DeleteStatusFave deletes one status fave with the given id.
DeleteStatusFave(ctx context.Context, id string) Error
DeleteStatusFaveByID(ctx context.Context, id string) Error
// DeleteStatusFaves mass deletes status faves targeting targetAccountID
// and/or originating from originAccountID and/or faving statusID.

View File

@ -383,7 +383,7 @@ func (d *deref) populateStatusMentions(ctx context.Context, status *gtsmodel.Sta
TargetAccountURL: targetAccount.URL,
}
if err := d.db.Put(ctx, newMention); err != nil {
if err := d.db.PutMention(ctx, newMention); err != nil {
return fmt.Errorf("populateStatusMentions: error creating mention: %s", err)
}

View File

@ -25,8 +25,6 @@ import (
"codeberg.org/gruf/go-logger/v2/level"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/uris"
@ -63,16 +61,16 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
acceptedObjectIRI := iter.GetIRI()
if uris.IsFollowPath(acceptedObjectIRI) {
// ACCEPT FOLLOW
gtsFollowRequest := &gtsmodel.FollowRequest{}
if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil {
followReq, err := f.state.DB.GetFollowRequestByURI(ctx, acceptedObjectIRI.String())
if err != nil {
return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err)
}
// make sure the addressee of the original follow is the same as whatever inbox this landed in
if gtsFollowRequest.AccountID != receivingAccount.ID {
if followReq.AccountID != receivingAccount.ID {
return errors.New("ACCEPT: follow object account and inbox account were not the same")
}
follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID)
follow, err := f.state.DB.AcceptFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID)
if err != nil {
return err
}

View File

@ -262,7 +262,7 @@ func (f *federatingDB) activityFollow(ctx context.Context, asType vocab.Type, re
followRequest.ID = id.NewULID()
if err := f.state.DB.Put(ctx, followRequest); err != nil {
if err := f.state.DB.PutFollowRequest(ctx, followRequest); err != nil {
return fmt.Errorf("activityFollow: database error inserting follow request: %s", err)
}

View File

@ -38,7 +38,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
return nil, err
}
follows, err := f.state.DB.GetFollows(ctx, "", acct.ID)
follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID)
if err != nil {
return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err)
}

View File

@ -38,7 +38,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
return nil, err
}
follows, err := f.state.DB.GetFollows(ctx, acct.ID, "")
follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID)
if err != nil {
return nil, fmt.Errorf("Following: db error getting following for account id %s: %w", acct.ID, err)
}

View File

@ -89,7 +89,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err)
}
follows, err := f.state.DB.GetFollows(c, "", account.ID)
follows, err := f.state.DB.GetAccountFollowers(c, account.ID)
if err != nil {
return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err)
}

View File

@ -25,8 +25,6 @@ import (
"codeberg.org/gruf/go-logger/v2/level"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/uris"
)
@ -62,17 +60,17 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
rejectedObjectIRI := iter.GetIRI()
if uris.IsFollowPath(rejectedObjectIRI) {
// REJECT FOLLOW
gtsFollowRequest := &gtsmodel.FollowRequest{}
if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil {
followReq, err := f.state.DB.GetFollowRequestByURI(ctx, rejectedObjectIRI.String())
if err != nil {
return fmt.Errorf("Reject: couldn't get follow request with id %s from the database: %s", rejectedObjectIRI.String(), err)
}
// make sure the addressee of the original follow is the same as whatever inbox this landed in
if gtsFollowRequest.AccountID != receivingAccount.ID {
if followReq.AccountID != receivingAccount.ID {
return errors.New("Reject: follow object account and inbox account were not the same")
}
if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil {
if err := f.state.DB.RejectFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID); err != nil {
return err
}
@ -101,7 +99,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
if gtsFollow.AccountID != receivingAccount.ID {
return errors.New("Reject: follow object account and inbox account were not the same")
}
if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil {
if err := f.state.DB.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil {
return err
}

View File

@ -26,7 +26,6 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
@ -80,11 +79,11 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: follow object account and inbox account were not the same")
}
// delete any existing FOLLOW
if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.Follow{}); err != nil {
if err := f.state.DB.DeleteFollowByURI(ctx, gtsFollow.URI); err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("UNDO: db error removing follow: %s", err)
}
// delete any existing FOLLOW REQUEST
if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.FollowRequest{}); err != nil {
if err := f.state.DB.DeleteFollowRequestByURI(ctx, gtsFollow.URI); err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("UNDO: db error removing follow request: %s", err)
}
l.Debug("follow undone")

View File

@ -231,7 +231,7 @@ func (f *federatingDB) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (ac
// getAccountForIRI returns the account that corresponds to or owns the given IRI.
func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gtsmodel.Account, error) {
var (
acct = &gtsmodel.Account{}
acct *gtsmodel.Account
err error
)
@ -245,7 +245,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsInboxPath(iri):
if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil {
if acct, err = f.state.DB.GetAccountByInboxURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", iri.String())
}
@ -253,7 +253,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsOutboxPath(iri):
if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil {
if acct, err = f.state.DB.GetAccountByOutboxURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", iri.String())
}
@ -261,7 +261,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsFollowersPath(iri):
if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil {
if acct, err = f.state.DB.GetAccountByFollowersURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to followers_uri %s", iri.String())
}
@ -269,7 +269,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsFollowingPath(iri):
if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil {
if acct, err = f.state.DB.GetAccountByFollowingURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to following_uri %s", iri.String())
}

View File

@ -283,7 +283,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, errors.New("requesting account not set on request context, so couldn't determine blocks")
}
// the receiver shouldn't block the sender
blocked, err = f.db.IsBlocked(ctx, receivingAccount.ID, requestingAccount.ID, false)
blocked, err = f.db.IsBlocked(ctx, receivingAccount.ID, requestingAccount.ID)
if err != nil {
return false, fmt.Errorf("error checking user-level blocks: %s", err)
}
@ -309,7 +309,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
for _, involvedAccountID := range deduped {
// the involved account shouldn't block whoever is making this request
blocked, err = f.db.IsBlocked(ctx, involvedAccountID, requestingAccount.ID, false)
blocked, err = f.db.IsBlocked(ctx, involvedAccountID, requestingAccount.ID)
if err != nil {
return false, fmt.Errorf("error checking user-level otherInvolvedIRI blocks: %s", err)
}
@ -318,7 +318,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
}
// whoever is receiving this request shouldn't block the involved account
blocked, err = f.db.IsBlocked(ctx, receivingAccount.ID, involvedAccountID, false)
blocked, err = f.db.IsBlocked(ctx, receivingAccount.ID, involvedAccountID)
if err != nil {
return false, fmt.Errorf("error checking user-level otherInvolvedIRI blocks: %s", err)
}

View File

@ -0,0 +1,43 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtscontext
import "context"
// package private context key type.
type ctxkey uint
const (
// context keys.
_ ctxkey = iota
barebonesKey
)
// Barebones returns whether the "barebones" context key has been set. This
// can be used to indicate to the database, for example, that only a barebones
// model need be returned, Allowing it to skip populating sub models.
func Barebones(ctx context.Context) bool {
_, ok := ctx.Value(barebonesKey).(struct{})
return ok
}
// SetBarebones sets the "barebones" context flag and returns this wrapped context.
// See Barebones() for further information on the "barebones" context flag..
func SetBarebones(ctx context.Context) context.Context {
return context.WithValue(ctx, barebonesKey, struct{}{})
}

View File

@ -27,6 +27,7 @@ import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// Account represents either a local or a remote fediverse account, gotosocial or otherwise (mastodon, pleroma, etc).
@ -35,8 +36,8 @@ type Account struct {
CreatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created.
UpdatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item was last updated.
FetchedAt time.Time `validate:"required_with=Domain" bun:"type:timestamptz,nullzero"` // when was item (remote) last fetched.
Username string `validate:"required" bun:",nullzero,notnull,unique:userdomain"` // Username of the account, should just be a string of [a-zA-Z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``. Username and domain should be unique *with* each other
Domain string `validate:"omitempty,fqdn" bun:",nullzero,unique:userdomain"` // Domain of the account, will be null if this is a local account, otherwise something like ``example.org``. Should be unique with username.
Username string `validate:"required" bun:",nullzero,notnull,unique:usernamedomain"` // Username of the account, should just be a string of [a-zA-Z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``. Username and domain should be unique *with* each other
Domain string `validate:"omitempty,fqdn" bun:",nullzero,unique:usernamedomain"` // Domain of the account, will be null if this is a local account, otherwise something like ``example.org``. Should be unique with username.
AvatarMediaAttachmentID string `validate:"omitempty,ulid" bun:"type:CHAR(26),nullzero"` // Database ID of the media attachment, if present
AvatarMediaAttachment *MediaAttachment `validate:"-" bun:"rel:belongs-to"` // MediaAttachment corresponding to avatarMediaAttachmentID
AvatarRemoteURL string `validate:"omitempty,url" bun:",nullzero"` // For a non-local account, where can the header be fetched?
@ -70,8 +71,8 @@ type Account struct {
FollowersURI string `validate:"required_without=Domain,omitempty,url" bun:",nullzero,unique"` // URI for getting the followers list of this account
FeaturedCollectionURI string `validate:"required_without=Domain,omitempty,url" bun:",nullzero,unique"` // URL for getting the featured collection list of this account
ActorType string `validate:"oneof=Application Group Organization Person Service" bun:",nullzero,notnull"` // What type of activitypub actor is this account?
PrivateKey *rsa.PrivateKey `validate:"required_without=Domain"` // Privatekey for validating activitypub requests, will only be defined for local accounts
PublicKey *rsa.PublicKey `validate:"required"` // Publickey for encoding activitypub requests, will be defined for both local and remote accounts
PrivateKey *rsa.PrivateKey `validate:"required_without=Domain" bun:""` // Privatekey for validating activitypub requests, will only be defined for local accounts
PublicKey *rsa.PublicKey `validate:"required" bun:",notnull,unique"` // Publickey for encoding activitypub requests, will be defined for both local and remote accounts
PublicKeyURI string `validate:"required,url" bun:",nullzero,notnull,unique"` // Web-reachable location of this account's public key
SensitizedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero"` // When was this account set to have all its media shown as sensitive?
SilencedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero"` // When was this account silenced (eg., statuses only visible to followers, not public)?
@ -82,23 +83,44 @@ type Account struct {
}
// IsLocal returns whether account is a local user account.
func (a Account) IsLocal() bool {
func (a *Account) IsLocal() bool {
return a.Domain == "" || a.Domain == config.GetHost() || a.Domain == config.GetAccountDomain()
}
// IsRemote returns whether account is a remote user account.
func (a Account) IsRemote() bool {
func (a *Account) IsRemote() bool {
return !a.IsLocal()
}
// IsInstance returns whether account is an instance internal actor account.
func (a Account) IsInstance() bool {
func (a *Account) IsInstance() bool {
return a.Username == a.Domain ||
a.FollowersURI == "" ||
a.FollowingURI == "" ||
(a.Username == "internal.fetch" && strings.Contains(a.Note, "internal service actor"))
}
// EmojisPopulated returns whether emojis are populated according to current EmojiIDs.
func (a *Account) EmojisPopulated() bool {
if len(a.EmojiIDs) != len(a.Emojis) {
// this is the quickest indicator.
return false
}
// Emojis must be in same order.
for i, id := range a.EmojiIDs {
if a.Emojis[i] == nil {
log.Warnf(nil, "nil emoji in slice for account %s", a.URI)
continue
}
if a.Emojis[i].ID != id {
return false
}
}
return true
}
// AccountToEmoji is an intermediate struct to facilitate the many2many relationship between an account and one or more emojis.
type AccountToEmoji struct {
AccountID string `validate:"ulid,required" bun:"type:CHAR(26),unique:accountemoji,nullzero,notnull"`

View File

@ -19,6 +19,8 @@ package gtsmodel
import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// Status represents a user-created 'post' or 'status' in the database, either remote or local
@ -65,27 +67,120 @@ type Status struct {
Likeable *bool `validate:"-" bun:",notnull"` // This status can be liked/faved
}
/*
The below functions are added onto the gtsmodel status so that it satisfies
the Timelineable interface in internal/timeline.
*/
// GetID implements timeline.Timelineable{}.
func (s *Status) GetID() string {
return s.ID
}
// GetAccountID implements timeline.Timelineable{}.
func (s *Status) GetAccountID() string {
return s.AccountID
}
// GetBoostID implements timeline.Timelineable{}.
func (s *Status) GetBoostOfID() string {
return s.BoostOfID
}
// GetBoostOfAccountID implements timeline.Timelineable{}.
func (s *Status) GetBoostOfAccountID() string {
return s.BoostOfAccountID
}
// AttachmentsPopulated returns whether media attachments are populated according to current AttachmentIDs.
func (s *Status) AttachmentsPopulated() bool {
if len(s.AttachmentIDs) != len(s.Attachments) {
// this is the quickest indicator.
return false
}
// Attachments must be in same order.
for i, id := range s.AttachmentIDs {
if s.Attachments[i] == nil {
log.Warnf(nil, "nil attachment in slice for status %s", s.URI)
continue
}
if s.Attachments[i].ID != id {
return false
}
}
return true
}
// TagsPopulated returns whether tags are populated according to current TagIDs.
func (s *Status) TagsPopulated() bool {
if len(s.TagIDs) != len(s.Tags) {
// this is the quickest indicator.
return false
}
// Tags must be in same order.
for i, id := range s.TagIDs {
if s.Tags[i] == nil {
log.Warnf(nil, "nil tag in slice for status %s", s.URI)
continue
}
if s.Tags[i].ID != id {
return false
}
}
return true
}
// MentionsPopulated returns whether mentions are populated according to current MentionIDs.
func (s *Status) MentionsPopulated() bool {
if len(s.MentionIDs) != len(s.Mentions) {
// this is the quickest indicator.
return false
}
// Mentions must be in same order.
for i, id := range s.MentionIDs {
if s.Mentions[i] == nil {
log.Warnf(nil, "nil mention in slice for status %s", s.URI)
continue
}
if s.Mentions[i].ID != id {
return false
}
}
return true
}
// EmojisPopulated returns whether emojis are populated according to current EmojiIDs.
func (s *Status) EmojisPopulated() bool {
if len(s.EmojiIDs) != len(s.Emojis) {
// this is the quickest indicator.
return false
}
// Emojis must be in same order.
for i, id := range s.EmojiIDs {
if s.Emojis[i] == nil {
log.Warnf(nil, "nil emoji in slice for status %s", s.URI)
continue
}
if s.Emojis[i].ID != id {
return false
}
}
return true
}
// MentionsAccount returns whether status mentions the given account ID.
func (s *Status) MentionsAccount(id string) bool {
for _, mention := range s.Mentions {
if mention.TargetAccountID == id {
return true
}
}
return false
}
// StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags.
type StatusToTag struct {
StatusID string `validate:"ulid,required" bun:"type:CHAR(26),unique:statustag,nullzero,notnull"`

View File

@ -36,7 +36,7 @@ type Processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
oauthServer oauth.Server
filter visibility.Filter
filter *visibility.Filter
formatter text.Formatter
federator federation.Federator
parseMention gtsmodel.ParseMentionFunc
@ -49,6 +49,7 @@ func New(
mediaManager media.Manager,
oauthServer oauth.Server,
federator federation.Federator,
filter *visibility.Filter,
parseMention gtsmodel.ParseMentionFunc,
) Processor {
return Processor{
@ -56,7 +57,7 @@ func New(
tc: tc,
mediaManager: mediaManager,
oauthServer: oauthServer,
filter: visibility.NewFilter(state.DB),
filter: filter,
formatter: text.NewFormatter(state.DB),
federator: federator,
parseMention: parseMention,

View File

@ -34,6 +34,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -101,7 +102,9 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator))
filter := visibility.NewFilter(&suite.state)
suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
}

View File

@ -56,7 +56,7 @@ func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmode
return nil, gtserror.NewErrorInternalError(err) // A real error has occurred.
}
visible, err := p.filter.StatusVisible(ctx, status, requestingAccount)
visible, err := p.filter.StatusVisible(ctx, requestingAccount, status)
if err != nil {
log.Errorf(ctx, "error checking bookmarked status visibility: %s", err)
continue

View File

@ -150,25 +150,25 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *
// - Follow requests created by account.
func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.Account) error {
// Delete follows targeting this account.
followedBy, err := p.state.DB.GetFollows(ctx, "", account.ID)
followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("deleteAccountFollows: db error getting follows targeting account %s: %w", account.ID, err)
}
for _, follow := range followedBy {
if _, err := p.state.DB.Unfollow(ctx, follow.AccountID, account.ID); err != nil {
if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil {
return fmt.Errorf("deleteAccountFollows: db error unfollowing account followedBy: %w", err)
}
}
// Delete follow requests targeting this account.
followRequestedBy, err := p.state.DB.GetFollowRequests(ctx, "", account.ID)
followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("deleteAccountFollows: db error getting follow requests targeting account %s: %w", account.ID, err)
}
for _, followRequest := range followRequestedBy {
if _, err := p.state.DB.UnfollowRequest(ctx, followRequest.AccountID, account.ID); err != nil {
if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil {
return fmt.Errorf("deleteAccountFollows: db error unfollowing account followRequestedBy: %w", err)
}
}
@ -183,7 +183,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
)
// Delete follows originating from this account.
following, err := p.state.DB.GetFollows(ctx, account.ID, "")
following, err := p.state.DB.GetAccountFollows(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("deleteAccountFollows: db error getting follows owned by account %s: %w", account.ID, err)
}
@ -191,15 +191,9 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
// For each follow owned by this account, unfollow
// and process side effects (noop if remote account).
for _, follow := range following {
if uri, err := p.state.DB.Unfollow(ctx, account.ID, follow.TargetAccountID); err != nil {
if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil {
return fmt.Errorf("deleteAccountFollows: db error unfollowing account: %w", err)
} else if uri == "" {
// There was no follow after all.
// Some race condition? Skip.
log.WithContext(ctx).WithField("follow", follow).Warn("Unfollow did not return uri, likely race condition")
continue
}
if msg := unfollowSideEffects(ctx, account, follow); msg != nil {
// There was a side effect to process.
msgs = append(msgs, *msg)
@ -207,7 +201,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
}
// Delete follow requests originating from this account.
followRequesting, err := p.state.DB.GetFollowRequests(ctx, account.ID, "")
followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("deleteAccountFollows: db error getting follow requests owned by account %s: %w", account.ID, err)
}
@ -215,23 +209,15 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
// For each follow owned by this account, unfollow
// and process side effects (noop if remote account).
for _, followRequest := range followRequesting {
uri, err := p.state.DB.UnfollowRequest(ctx, account.ID, followRequest.TargetAccountID)
if err != nil {
return fmt.Errorf("deleteAccountFollows: db error unfollowRequesting account: %w", err)
}
if uri == "" {
// There was no follow request after all.
// Some race condition? Skip.
log.WithContext(ctx).WithField("followRequest", followRequest).Warn("UnfollowRequest did not return uri, likely race condition")
continue
if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil {
return fmt.Errorf("deleteAccountFollows: db error unfollowingRequesting account: %w", err)
}
// Dummy out a follow so our side effects func
// has something to work with. This follow will
// never enter the db, it's just for convenience.
follow := &gtsmodel.Follow{
URI: uri,
URI: followRequest.URI,
AccountID: followRequest.AccountID,
Account: followRequest.Account,
TargetAccountID: followRequest.TargetAccountID,
@ -284,16 +270,9 @@ func (p *Processor) unfollowSideEffectsFunc(deletedAccount *gtsmodel.Account) fu
}
func (p *Processor) deleteAccountBlocks(ctx context.Context, account *gtsmodel.Account) error {
// Delete blocks created by this account.
if err := p.state.DB.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil {
return fmt.Errorf("deleteAccountBlocks: db error deleting blocks created by account %s: %w", account.ID, err)
if err := p.state.DB.DeleteAccountBlocks(ctx, account.ID); err != nil {
return fmt.Errorf("deleteAccountBlocks: db error deleting account blocks for %s: %w", account.ID, err)
}
// Delete blocks targeting this account.
if err := p.state.DB.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil {
return fmt.Errorf("deleteAccountBlocks: db error deleting blocks targeting account %s: %w", account.ID, err)
}
return nil
}
@ -386,13 +365,13 @@ statusLoop:
}
func (p *Processor) deleteAccountNotifications(ctx context.Context, account *gtsmodel.Account) error {
// Delete all notifications targeting given account.
if err := p.state.DB.DeleteNotifications(ctx, account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) {
// Delete all notifications of all types targeting given account.
if err := p.state.DB.DeleteNotifications(ctx, nil, account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
// Delete all notifications originating from given account.
if err := p.state.DB.DeleteNotifications(ctx, "", account.ID); err != nil && !errors.Is(err, db.ErrNoEntries) {
// Delete all notifications of all types originating from given account.
if err := p.state.DB.DeleteNotifications(ctx, nil, "", account.ID); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}

View File

@ -40,7 +40,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
// Check if a follow exists already.
if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount, targetAccount); err != nil {
if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount.ID, targetAccount.ID); err != nil {
err = fmt.Errorf("FollowCreate: db error checking follow: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if follows {
@ -49,7 +49,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
// Check if a follow request exists already.
if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount, targetAccount); err != nil {
if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount.ID, targetAccount.ID); err != nil {
err = fmt.Errorf("FollowCreate: db error checking follow request: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if followRequested {
@ -75,7 +75,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
Notify: form.Notify,
}
if err := p.state.DB.Put(ctx, fr); err != nil {
if err := p.state.DB.PutFollowRequest(ctx, fr); err != nil {
err = fmt.Errorf("FollowCreate: error creating follow request in db: %s", err)
return nil, gtserror.NewErrorInternalError(err)
}
@ -141,7 +141,7 @@ func (p *Processor) getFollowTarget(ctx context.Context, requestingAccountID str
}
// Do nothing if a block exists in either direction between accounts.
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccountID, targetAccountID, true); err != nil {
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, targetAccountID); err != nil {
err = fmt.Errorf("db error checking block between accounts: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@ -173,12 +173,30 @@ func (p *Processor) getFollowTarget(ctx context.Context, requestingAccountID str
// messages will be returned which should then be processed by a client
// api worker.
func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) ([]messages.FromClientAPI, error) {
msgs := []messages.FromClientAPI{}
var msgs []messages.FromClientAPI
if fURI, err := p.state.DB.Unfollow(ctx, requestingAccount.ID, targetAccount.ID); err != nil {
err = fmt.Errorf("unfollow: error deleting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
// Get follow from requesting account to target account.
follow, err := p.state.DB.GetFollow(ctx, requestingAccount.ID, targetAccount.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("unfollow: error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
} else if fURI != "" {
}
if follow != nil {
// Delete known follow from database with ID.
err = p.state.DB.DeleteFollowByID(ctx, follow.ID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("unfollow: error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
}
// If err == db.ErrNoEntries here then it
// indicates a race condition with another
// unfollow for the same requester->target.
return msgs, nil
}
// Follow status changed, process side effects.
msgs = append(msgs, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
@ -186,25 +204,43 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
GTSModel: &gtsmodel.Follow{
AccountID: requestingAccount.ID,
TargetAccountID: targetAccount.ID,
URI: fURI,
URI: follow.URI,
},
OriginAccount: requestingAccount,
TargetAccount: targetAccount,
})
}
if frURI, err := p.state.DB.UnfollowRequest(ctx, requestingAccount.ID, targetAccount.ID); err != nil {
// Get follow request from requesting account to target account.
followReq, err := p.state.DB.GetFollowRequest(ctx, requestingAccount.ID, targetAccount.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("unfollow: error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
}
if followReq != nil {
// Delete known follow request from database with ID.
err = p.state.DB.DeleteFollowRequestByID(ctx, followReq.ID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("unfollow: error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
} else if frURI != "" {
// Follow request status changed, process side effects.
}
// If err == db.ErrNoEntries here then it
// indicates a race condition with another
// unfollow for the same requester->target.
return msgs, nil
}
// Follow status changed, process side effects.
msgs = append(msgs, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo,
GTSModel: &gtsmodel.Follow{
AccountID: requestingAccount.ID,
TargetAccountID: targetAccount.ID,
URI: frURI,
URI: followReq.URI,
},
OriginAccount: requestingAccount,
TargetAccount: targetAccount,

View File

@ -73,7 +73,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco
var blocked bool
var err error
if requestingAccount != nil {
blocked, err = p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true)
blocked, err = p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccount.ID)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking account block: %s", err))
}

View File

@ -31,7 +31,7 @@ import (
// FollowersGet fetches a list of the target account's followers.
func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil {
err = fmt.Errorf("FollowersGet: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@ -39,7 +39,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
return nil, gtserror.NewErrorNotFound(err)
}
follows, err := p.state.DB.GetFollows(ctx, "", targetAccountID)
follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FollowersGet: db error getting followers: %w", err)
@ -53,7 +53,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
// FollowingGet fetches a list of the accounts that target account is following.
func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil {
err = fmt.Errorf("FollowingGet: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@ -61,7 +61,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
return nil, gtserror.NewErrorNotFound(err)
}
follows, err := p.state.DB.GetFollows(ctx, targetAccountID, "")
follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FollowingGet: db error getting followers: %w", err)
@ -70,7 +70,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
return []apimodel.Account{}, nil
}
return p.accountsFromFollows(ctx, follows, requestingAccount.ID)
return p.targetAccountsFromFollows(ctx, follows, requestingAccount.ID)
}
// RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account.
@ -101,7 +101,7 @@ func (p *Processor) accountsFromFollows(ctx context.Context, follows []*gtsmodel
continue
}
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccountID, follow.AccountID, true); err != nil {
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.AccountID); err != nil {
err = fmt.Errorf("accountsFromFollows: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@ -113,8 +113,35 @@ func (p *Processor) accountsFromFollows(ctx context.Context, follows []*gtsmodel
err = fmt.Errorf("accountsFromFollows: error converting account to api account: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
accounts = append(accounts, *account)
}
return accounts, nil
}
func (p *Processor) targetAccountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) {
accounts := make([]apimodel.Account, 0, len(follows))
for _, follow := range follows {
if follow.TargetAccount == nil {
// No account set for some reason; just skip.
log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated target account")
continue
}
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.TargetAccountID); err != nil {
err = fmt.Errorf("targetAccountsFromFollows: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
continue
}
account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.TargetAccount)
if err != nil {
err = fmt.Errorf("targetAccountsFromFollows: error converting account to api account: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
accounts = append(accounts, *account)
}
return accounts, nil
}

View File

@ -19,6 +19,7 @@ package account
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@ -32,10 +33,11 @@ import (
// the account given in authed.
func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) {
if requestingAccount != nil {
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
err := errors.New("block exists between accounts")
return nil, gtserror.NewErrorNotFound(err)
}
}
@ -57,14 +59,10 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel
return nil, gtserror.NewErrorInternalError(err)
}
// Filtering + serialization process is the same for
// either pinned status queries or 'normal' ones.
filtered := make([]*gtsmodel.Status, 0, len(statuses))
for _, s := range statuses {
visible, err := p.filter.StatusVisible(ctx, s, requestingAccount)
if err == nil && visible {
filtered = append(filtered, s)
}
// Filtering + serialization process is the same for either pinned status queries or 'normal' ones.
filtered, err := p.filter.StatusesVisible(ctx, requestingAccount, statuses)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
count := len(filtered)

View File

@ -45,7 +45,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string)
return
}
blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
blocked, err := p.state.DB.IsEitherBlocked(ctx, requestedAccount.ID, requestingAccount.ID)
if err != nil {
errWithCode = gtserror.NewErrorInternalError(err)
return

View File

@ -28,15 +28,15 @@ type Processor struct {
state *state.State
federator federation.Federator
tc typeutils.TypeConverter
filter visibility.Filter
filter *visibility.Filter
}
// New returns a new fedi processor.
func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator) Processor {
func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator, filter *visibility.Filter) Processor {
return Processor{
state: state,
federator: federator,
tc: tc,
filter: visibility.NewFilter(state.DB),
filter: filter,
}
}

View File

@ -44,7 +44,7 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req
return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s does not belong to account with id %s", status.ID, requestedAccount.ID))
}
visible, err := p.filter.StatusVisible(ctx, status, requestingAccount)
visible, err := p.filter.StatusVisible(ctx, requestingAccount, status)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@ -82,7 +82,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s does not belong to account with id %s", status.ID, requestedAccount.ID))
}
visible, err := p.filter.StatusVisible(ctx, status, requestingAccount)
visible, err := p.filter.StatusVisible(ctx, requestedAccount, status)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@ -143,13 +143,13 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
}
// only show replies that the status owner can see
visibleToStatusOwner, err := p.filter.StatusVisible(ctx, r, requestedAccount)
visibleToStatusOwner, err := p.filter.StatusVisible(ctx, requestedAccount, r)
if err != nil || !visibleToStatusOwner {
continue
}
// only show replies that the requester can see
visibleToRequester, err := p.filter.StatusVisible(ctx, r, requestingAccount)
visibleToRequester, err := p.filter.StatusVisible(ctx, requestingAccount, r)
if err != nil || !visibleToRequester {
continue
}

View File

@ -62,7 +62,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque
return nil, gtserror.NewErrorUnauthorized(err)
}
blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
blocked, err := p.state.DB.IsEitherBlocked(ctx, requestedAccount.ID, requestingAccount.ID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}

View File

@ -31,7 +31,7 @@ import (
)
func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
followRequests, err := p.state.DB.GetFollowRequests(ctx, "", auth.Account.ID)
followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
@ -49,8 +49,10 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
accts = append(accts, *apiAcct)
}
return accts, nil
}
@ -79,7 +81,12 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
}
func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
followRequest, err := p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
followRequest, err := p.state.DB.GetFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
err = p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}

View File

@ -39,11 +39,12 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
if status.Mentions == nil {
// there are mentions but they're not fully populated on the status yet so do this
menchies, err := p.state.DB.GetMentions(ctx, status.MentionIDs)
mentions, err := p.state.DB.GetMentions(ctx, status.MentionIDs)
if err != nil {
return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err)
}
status.Mentions = menchies
status.Mentions = mentions
}
// now we have mentions as full gtsmodel.Mention structs on the status we can continue
@ -88,7 +89,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
Status: status,
}
if err := p.state.DB.Put(ctx, notif); err != nil {
if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyStatus: error putting notification in database: %s", err)
}
@ -130,7 +131,7 @@ func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsm
OriginAccountID: followRequest.AccountID,
}
if err := p.state.DB.Put(ctx, notif); err != nil {
if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err)
}
@ -171,7 +172,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
OriginAccountID: follow.AccountID,
OriginAccount: follow.Account,
}
if err := p.state.DB.Put(ctx, notif); err != nil {
if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyFollow: error putting notification in database: %s", err)
}
@ -219,7 +220,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e
Status: fave.Status,
}
if err := p.state.DB.Put(ctx, notif); err != nil {
if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyFave: error putting notification in database: %s", err)
}
@ -293,7 +294,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
Status: status,
}
if err := p.state.DB.Put(ctx, notif); err != nil {
if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err)
}
@ -403,39 +404,39 @@ func (p *Processor) notifyReportClosed(ctx context.Context, report *gtsmodel.Rep
// timelineStatus processes the given new status and inserts it into
// the HOME timelines of accounts that follow the status author.
func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error {
// make sure the author account is pinned onto the status
if status.Account == nil {
a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err)
// ensure status fully populated (including account)
if err := p.state.DB.PopulateStatus(ctx, status); err != nil {
return fmt.Errorf("timelineStatus: error populating status with id %s: %w", status.ID, err)
}
status.Account = a
}
// Get LOCAL followers of the account that posted the status;
// we know that remote accounts don't have timelines on this
// instance, so there's no point selecting them too.
accountIDs, err := p.state.DB.GetLocalFollowersIDs(ctx, status.AccountID)
// get local followers of the account that posted the status
follows, err := p.state.DB.GetAccountLocalFollowers(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err)
return fmt.Errorf("timelineStatus: error getting followers for account id %s: %w", status.AccountID, err)
}
// If the poster is also local, add a fake entry for them
// so they can see their own status in their timeline.
if status.Account.IsLocal() {
accountIDs = append(accountIDs, status.AccountID)
follows = append(follows, &gtsmodel.Follow{
AccountID: status.AccountID,
Account: status.Account,
})
}
var errs gtserror.MultiError
for _, follow := range follows {
// Timeline the status for each local following account.
errors := gtserror.MultiError{}
for _, accountID := range accountIDs {
if err := p.timelineStatusForAccount(ctx, status, accountID); err != nil {
errors.Append(err)
if err := p.timelineStatusForAccount(ctx, follow.Account, status); err != nil {
errs.Append(err)
}
}
if len(errors) != 0 {
return fmt.Errorf("timelineStatus: one or more errors timelining statuses: %w", errors.Combine())
if len(errs) != 0 {
return fmt.Errorf("timelineStatus: one or more errors timelining statuses: %w", errs.Combine())
}
return nil
@ -446,34 +447,28 @@ func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status)
//
// If the status was inserted into the home timeline of the given account,
// it will also be streamed via websockets to the user.
func (p *Processor) timelineStatusForAccount(ctx context.Context, status *gtsmodel.Status, accountID string) error {
// get the timeline owner account
timelineAccount, err := p.state.DB.GetAccountByID(ctx, accountID)
if err != nil {
return fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %w", accountID, err)
}
func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmodel.Account, status *gtsmodel.Status) error {
// make sure the status is timelineable
if timelineable, err := p.filter.StatusHometimelineable(ctx, status, timelineAccount); err != nil {
return fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", accountID, err)
if timelineable, err := p.filter.StatusHomeTimelineable(ctx, account, status); err != nil {
return fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", account.ID, err)
} else if !timelineable {
return nil
}
// stick the status in the timeline for the account and then immediately prepare it so they can see it right away
if inserted, err := p.statusTimelines.IngestAndPrepare(ctx, status, timelineAccount.ID); err != nil {
if inserted, err := p.statusTimelines.IngestAndPrepare(ctx, status, account.ID); err != nil {
return fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %w", status.ID, err)
} else if !inserted {
return nil
}
// the status was inserted so stream it to the user
apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, timelineAccount)
apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, account)
if err != nil {
return fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %w", status.ID, err)
}
if err := p.stream.Update(apiStatus, timelineAccount, stream.TimelineHome); err != nil {
if err := p.stream.Update(apiStatus, account, stream.TimelineHome); err != nil {
return fmt.Errorf("timelineStatusForAccount: error streaming update for status %s: %w", status.ID, err)
}
@ -513,8 +508,8 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
}
// delete all mention entries generated by this status
for _, m := range statusToDelete.MentionIDs {
if err := p.state.DB.DeleteByID(ctx, m, &gtsmodel.Mention{}); err != nil {
for _, id := range statusToDelete.MentionIDs {
if err := p.state.DB.DeleteMentionByID(ctx, id); err != nil {
return err
}
}

View File

@ -358,10 +358,10 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {
suite.ErrorIs(err, db.ErrNoEntries)
// the mufos should be gone now too
satanFollowsZork, err := suite.db.IsFollowing(ctx, deletedAccount, receivingAccount)
satanFollowsZork, err := suite.db.IsFollowing(ctx, deletedAccount.ID, receivingAccount.ID)
suite.NoError(err)
suite.False(satanFollowsZork)
zorkFollowsSatan, err := suite.db.IsFollowing(ctx, receivingAccount, deletedAccount)
zorkFollowsSatan, err := suite.db.IsFollowing(ctx, receivingAccount.ID, deletedAccount.ID)
suite.NoError(err)
suite.False(zorkFollowsSatan)

View File

@ -63,7 +63,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc
// make sure the requesting account and the media account don't block each other
if requestingAccount != nil {
blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true)
blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, owningAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", owningAccountID, requestingAccount.ID, err))
}

View File

@ -30,7 +30,7 @@ import (
)
func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, excludeTypes []string, limit int, maxID string, sinceID string) (*apimodel.PageableResponse, gtserror.WithCode) {
notifs, err := p.state.DB.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
notifs, err := p.state.DB.GetAccountNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@ -73,8 +73,8 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ex
}
func (p *Processor) NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode {
// Delete all notifications that target the authorized account.
if err := p.state.DB.DeleteNotifications(ctx, authed.Account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) {
// Delete all notifications of all types that target the authorized account.
if err := p.state.DB.DeleteNotifications(ctx, nil, authed.Account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.NewErrorInternalError(err)
}

View File

@ -47,8 +47,8 @@ type Processor struct {
mediaManager mm.Manager
statusTimelines timeline.Manager
state *state.State
filter visibility.Filter
emailSender email.Sender
filter *visibility.Filter
/*
SUB-PROCESSORS
@ -107,7 +107,7 @@ func NewProcessor(
) *Processor {
parseMentionFunc := GetParseMentionFunc(state.DB, federator)
filter := visibility.NewFilter(state.DB)
filter := visibility.NewFilter(state)
processor := &Processor{
federator: federator,
@ -126,12 +126,12 @@ func NewProcessor(
}
// sub processors
processor.account = account.New(state, tc, mediaManager, oauthServer, federator, parseMentionFunc)
processor.account = account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc)
processor.admin = admin.New(state, tc, mediaManager, federator.TransportController(), emailSender)
processor.fedi = fedi.New(state, tc, federator)
processor.fedi = fedi.New(state, tc, federator, filter)
processor.media = media.New(state, tc, mediaManager, federator.TransportController())
processor.report = report.New(state, tc)
processor.status = status.New(state, tc, parseMentionFunc)
processor.status = status.New(state, tc, filter, parseMentionFunc)
processor.stream = stream.New(state, oauthServer)
processor.user = user.New(state, emailSender)
@ -139,22 +139,24 @@ func NewProcessor(
}
func (p *Processor) EnqueueClientAPI(ctx context.Context, msgs ...messages.FromClientAPI) {
log.Trace(ctx, "enqueuing client API")
log.Trace(ctx, "enqueuing")
_ = p.state.Workers.ClientAPI.MustEnqueueCtx(ctx, func(ctx context.Context) {
for _, msg := range msgs {
log.Trace(ctx, "processing: %+v", msg)
if err := p.ProcessFromClientAPI(ctx, msg); err != nil {
log.WithContext(ctx).WithField("msg", msg).Errorf("error processing client API message: %v", err)
log.Errorf(ctx, "error processing client API message: %v", err)
}
}
})
}
func (p *Processor) EnqueueFederator(ctx context.Context, msgs ...messages.FromFederator) {
log.Trace(ctx, "enqueuing federator")
log.Trace(ctx, "enqueuing")
_ = p.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) {
for _, msg := range msgs {
log.Trace(ctx, "processing: %+v", msg)
if err := p.ProcessFromFederator(ctx, msg); err != nil {
log.WithContext(ctx).WithField("msg", msg).Errorf("error processing federator message: %v", err)
log.Errorf(ctx, "error processing federator message: %v", err)
}
}
})

View File

@ -177,7 +177,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
*/
for _, foundAccount := range foundAccounts {
// make sure there's no block in either direction between the account and the requester
blocked, err := p.state.DB.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true)
blocked, err := p.state.DB.IsEitherBlocked(ctx, authed.Account.ID, foundAccount.ID)
if err != nil {
err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err)
return nil, gtserror.NewErrorInternalError(err)
@ -199,7 +199,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
for _, foundStatus := range foundStatuses {
// make sure each found status is visible to the requester
visible, err := p.filter.StatusVisible(ctx, foundStatus, authed.Account)
visible, err := p.filter.StatusVisible(ctx, authed.Account, foundStatus)
if err != nil {
err = fmt.Errorf("SearchGet: error checking visibility of status %s for account %s: %s", foundStatus.ID, authed.Account.ID, err)
return nil, gtserror.NewErrorInternalError(err)

View File

@ -55,12 +55,11 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
targetStatus = targetStatus.BoostOf
}
boostable, err := p.filter.StatusBoostable(ctx, targetStatus, requestingAccount)
boostable, err := p.filter.StatusBoostable(ctx, requestingAccount, targetStatus)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is boostable: %s", targetStatus.ID, err))
}
if !boostable {
return nil, gtserror.NewErrorForbidden(errors.New("status is not boostable"))
} else if !boostable {
return nil, gtserror.NewErrorNotFound(errors.New("status is not boostable"))
}
// it's visible! it's boostable! so let's boost the FUCK out of it
@ -99,7 +98,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
visible, err := p.filter.StatusVisible(ctx, requestingAccount, targetStatus)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@ -180,7 +179,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
targetStatus = boostedStatus
}
visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
visible, err := p.filter.StatusVisible(ctx, requestingAccount, targetStatus)
if err != nil {
err = fmt.Errorf("BoostedBy: error seeing if status %s is visible: %s", targetStatus.ID, err)
return nil, gtserror.NewErrorNotFound(err)
@ -199,7 +198,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
// filter account IDs so the user doesn't see accounts they blocked or which blocked them
accountIDs := make([]string, 0, len(statusReblogs))
for _, s := range statusReblogs {
blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)
blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, s.AccountID)
if err != nil {
err = fmt.Errorf("BoostedBy: error checking blocks: %s", err)
return nil, gtserror.NewErrorNotFound(err)

View File

@ -43,12 +43,7 @@ func (p *Processor) getVisibleStatus(ctx context.Context, requestingAccount *gts
return nil, gtserror.NewErrorNotFound(err)
}
if targetStatus.Account == nil {
err = fmt.Errorf("getVisibleStatus: no status owner for status %s", targetStatusID)
return nil, gtserror.NewErrorNotFound(err)
}
visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
visible, err := p.filter.StatusVisible(ctx, requestingAccount, targetStatus)
if err != nil {
err = fmt.Errorf("getVisibleStatus: error seeing if status %s is visible: %w", targetStatus.ID, err)
return nil, gtserror.NewErrorNotFound(err)

View File

@ -133,7 +133,7 @@ func processReplyToID(ctx context.Context, dbService db.DB, form *apimodel.Advan
return gtserror.NewErrorInternalError(err)
}
if blocked, err := dbService.IsBlocked(ctx, thisAccountID, repliedAccount.ID, true); err != nil {
if blocked, err := dbService.IsEitherBlocked(ctx, thisAccountID, repliedAccount.ID); err != nil {
err := fmt.Errorf("db error checking block: %s", err)
return gtserror.NewErrorInternalError(err)
} else if blocked {

View File

@ -88,7 +88,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.
}
// We have a fave to remove.
if err := p.state.DB.DeleteStatusFave(ctx, existingFave.ID); err != nil {
if err := p.state.DB.DeleteStatusFaveByID(ctx, existingFave.ID); err != nil {
err = fmt.Errorf("FaveRemove: error removing status fave: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
@ -112,7 +112,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc
return nil, errWithCode
}
statusFaves, err := p.state.DB.GetStatusFaves(ctx, targetStatus.ID)
statusFaves, err := p.state.DB.GetStatusFavesForStatus(ctx, targetStatus.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("FavedBy: error seeing who faved status: %s", err))
}
@ -122,7 +122,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc
// and which don't block them.
apiAccounts := make([]*apimodel.Account, 0, len(statusFaves))
for _, fave := range statusFaves {
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true); err != nil {
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, fave.AccountID); err != nil {
err = fmt.Errorf("FavedBy: error checking blocks: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@ -157,7 +157,7 @@ func (p *Processor) getFaveTarget(ctx context.Context, requestingAccount *gtsmod
return nil, nil, gtserror.NewErrorForbidden(err, err.Error())
}
fave, err := p.state.DB.GetStatusFaveByAccountID(ctx, requestingAccount.ID, targetStatusID)
fave, err := p.state.DB.GetStatusFave(ctx, requestingAccount.ID, targetStatusID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("getFaveTarget: error checking existing fave: %w", err)
return nil, nil, gtserror.NewErrorInternalError(err)

View File

@ -54,7 +54,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.
}
for _, status := range parents {
if v, err := p.filter.StatusVisible(ctx, status, requestingAccount); err == nil && v {
if v, err := p.filter.StatusVisible(ctx, requestingAccount, status); err == nil && v {
apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, requestingAccount)
if err == nil {
context.Ancestors = append(context.Ancestors, *apiStatus)
@ -72,7 +72,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.
}
for _, status := range children {
if v, err := p.filter.StatusVisible(ctx, status, requestingAccount); err == nil && v {
if v, err := p.filter.StatusVisible(ctx, requestingAccount, status); err == nil && v {
apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, requestingAccount)
if err == nil {
context.Descendants = append(context.Descendants, *apiStatus)

View File

@ -28,17 +28,17 @@ import (
type Processor struct {
state *state.State
tc typeutils.TypeConverter
filter visibility.Filter
filter *visibility.Filter
formatter text.Formatter
parseMention gtsmodel.ParseMentionFunc
}
// New returns a new status processor.
func New(state *state.State, tc typeutils.TypeConverter, parseMention gtsmodel.ParseMentionFunc) Processor {
func New(state *state.State, tc typeutils.TypeConverter, filter *visibility.Filter, parseMention gtsmodel.ParseMentionFunc) Processor {
return Processor{
state: state,
tc: tc,
filter: visibility.NewFilter(state.DB),
filter: filter,
formatter: text.NewFormatter(state.DB),
parseMention: parseMention,
}

View File

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -85,7 +86,9 @@ func (suite *StatusStandardTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager)
suite.status = status.New(&suite.state, suite.typeConverter, processing.GetParseMentionFunc(suite.db, suite.federator))
filter := visibility.NewFilter(&suite.state)
suite.status = status.New(&suite.state, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")

View File

@ -47,9 +47,9 @@ func StatusGrabFunction(database db.DB) timeline.GrabFunction {
return nil, false, fmt.Errorf("statusGrabFunction: error getting statuses from db: %s", err)
}
items := []timeline.Timelineable{}
for _, s := range statuses {
items = append(items, s)
items := make([]timeline.Timelineable, len(statuses))
for i, s := range statuses {
items[i] = s
}
return items, false, nil
@ -57,7 +57,7 @@ func StatusGrabFunction(database db.DB) timeline.GrabFunction {
}
// StatusFilterFunction returns a function that satisfies the FilterFunction interface in internal/timeline.
func StatusFilterFunction(database db.DB, filter visibility.Filter) timeline.FilterFunction {
func StatusFilterFunction(database db.DB, filter *visibility.Filter) timeline.FilterFunction {
return func(ctx context.Context, timelineAccountID string, item timeline.Timelineable) (shouldIndex bool, err error) {
status, ok := item.(*gtsmodel.Status)
if !ok {
@ -69,7 +69,7 @@ func StatusFilterFunction(database db.DB, filter visibility.Filter) timeline.Fil
return false, fmt.Errorf("statusFilterFunction: error getting account with id %s", timelineAccountID)
}
timelineable, err := filter.StatusHometimelineable(ctx, status, requestingAccount)
timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status)
if err != nil {
log.Warnf(ctx, "error checking hometimelineability of status %s for account %s: %s", status.ID, timelineAccountID, err)
}
@ -253,8 +253,7 @@ func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, ma
func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
apiStatuses := []*apimodel.Status{}
for _, s := range statuses {
targetAccount := &gtsmodel.Account{}
if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {
if _, err := p.state.DB.GetAccountByID(ctx, s.AccountID); err != nil {
if err == db.ErrNoEntries {
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
@ -262,7 +261,7 @@ func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth
return nil, gtserror.NewErrorInternalError(fmt.Errorf("filterPublicStatuses: error getting status author: %s", err))
}
timelineable, err := p.filter.StatusPublictimelineable(ctx, s, authed.Account)
timelineable, err := p.filter.StatusPublicTimelineable(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
@ -286,8 +285,7 @@ func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth
func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
apiStatuses := []*apimodel.Status{}
for _, s := range statuses {
targetAccount := &gtsmodel.Account{}
if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {
if _, err := p.state.DB.GetAccountByID(ctx, s.AccountID); err != nil {
if err == db.ErrNoEntries {
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
@ -295,7 +293,7 @@ func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth,
return nil, gtserror.NewErrorInternalError(fmt.Errorf("filterPublicStatuses: error getting status author: %s", err))
}
timelineable, err := p.filter.StatusVisible(ctx, s, authed.Account)
timelineable, err := p.filter.StatusVisible(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue

View File

@ -240,7 +240,7 @@ func (r *customRenderer) renderMention(w mdutil.BufWriter, source []byte, node a
n, ok := node.(*mention) // this function is only registered for kindMention
if !ok {
log.Panic(nil, "type assertion failed")
log.Panic(r.ctx, "type assertion failed")
}
text := string(n.Segment.Value(source))
@ -248,7 +248,7 @@ func (r *customRenderer) renderMention(w mdutil.BufWriter, source []byte, node a
// we don't have much recourse if this fails
if _, err := w.WriteString(html); err != nil {
log.Errorf(nil, "error writing HTML: %s", err)
log.Errorf(r.ctx, "error writing HTML: %s", err)
}
return ast.WalkSkipChildren, nil
}
@ -260,7 +260,7 @@ func (r *customRenderer) renderHashtag(w mdutil.BufWriter, source []byte, node a
n, ok := node.(*hashtag) // this function is only registered for kindHashtag
if !ok {
log.Panic(nil, "type assertion failed")
log.Panic(r.ctx, "type assertion failed")
}
text := string(n.Segment.Value(source))
@ -269,7 +269,7 @@ func (r *customRenderer) renderHashtag(w mdutil.BufWriter, source []byte, node a
_, err := w.WriteString(html)
// we don't have much recourse if this fails
if err != nil {
log.Errorf(nil, "error writing HTML: %s", err)
log.Errorf(r.ctx, "error writing HTML: %s", err)
}
return ast.WalkSkipChildren, nil
}
@ -282,7 +282,7 @@ func (r *customRenderer) renderEmoji(w mdutil.BufWriter, source []byte, node ast
n, ok := node.(*emoji) // this function is only registered for kindEmoji
if !ok {
log.Panic(nil, "type assertion failed")
log.Panic(r.ctx, "type assertion failed")
}
text := string(n.Segment.Value(source))
shortcode := text[1 : len(text)-1]
@ -307,7 +307,7 @@ func (r *customRenderer) renderEmoji(w mdutil.BufWriter, source []byte, node ast
// we don't have much recourse if this fails
if _, err := w.WriteString(text); err != nil {
log.Errorf(nil, "error writing HTML: %s", err)
log.Errorf(r.ctx, "error writing HTML: %s", err)
}
return ast.WalkSkipChildren, nil
}

View File

@ -22,6 +22,7 @@ import (
"strings"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/text/unicode/norm"
@ -37,15 +38,15 @@ const (
// replaceMention takes a string in the form @username@domain.com or @localusername
func (r *customRenderer) replaceMention(text string) string {
menchie, err := r.parseMention(r.ctx, text, r.accountID, r.statusID)
mention, err := r.parseMention(r.ctx, text, r.accountID, r.statusID)
if err != nil {
log.Errorf(nil, "error parsing mention %s from status: %s", text, err)
log.Errorf(r.ctx, "error parsing mention %s from status: %s", text, err)
return text
}
if r.statusID != "" {
if err := r.f.db.Put(r.ctx, menchie); err != nil {
log.Errorf(nil, "error putting mention in db: %s", err)
if err := r.f.db.PutMention(r.ctx, mention); err != nil {
log.Errorf(r.ctx, "error putting mention in db: %s", err)
return text
}
}
@ -53,27 +54,29 @@ func (r *customRenderer) replaceMention(text string) string {
// only append if it's not been listed yet
listed := false
for _, m := range r.result.Mentions {
if menchie.ID == m.ID {
if mention.ID == m.ID {
listed = true
break
}
}
if !listed {
r.result.Mentions = append(r.result.Mentions, menchie)
r.result.Mentions = append(r.result.Mentions, mention)
}
// make sure we have an account attached to this mention
if menchie.TargetAccount == nil {
a, err := r.f.db.GetAccountByID(r.ctx, menchie.TargetAccountID)
if mention.TargetAccount == nil {
// Fetch mention target account if not yet populated.
mention.TargetAccount, err = r.f.db.GetAccountByID(
gtscontext.SetBarebones(r.ctx),
mention.TargetAccountID,
)
if err != nil {
log.Errorf(nil, "error getting account with id %s from the db: %s", menchie.TargetAccountID, err)
log.Errorf(r.ctx, "error populating mention target account: %v", err)
return text
}
menchie.TargetAccount = a
}
// The mention's target is our target
targetAccount := menchie.TargetAccount
targetAccount := mention.TargetAccount
var b strings.Builder
@ -105,7 +108,7 @@ func (r *customRenderer) replaceHashtag(text string) string {
tag, err := r.f.db.TagStringToTag(r.ctx, normalized, r.accountID)
if err != nil {
log.Errorf(nil, "error generating hashtags from status: %s", err)
log.Errorf(r.ctx, "error generating hashtags from status: %s", err)
return text
}
@ -121,7 +124,7 @@ func (r *customRenderer) replaceHashtag(text string) string {
err = r.f.db.Put(r.ctx, tag)
if err != nil {
if !errors.Is(err, db.ErrAlreadyExists) {
log.Errorf(nil, "error putting tags in db: %s", err)
log.Errorf(r.ctx, "error putting tags in db: %s", err)
return text
}
}

View File

@ -26,7 +26,6 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@ -42,15 +41,14 @@ func (suite *GetTestSuite) SetupSuite() {
}
func (suite *GetTestSuite) SetupTest() {
var state state.State
state.Caches.Init()
suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
suite.db = testrig.NewTestDB(&state)
suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.filter = visibility.NewFilter(suite.db)
suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)

View File

@ -25,7 +25,6 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@ -41,15 +40,14 @@ func (suite *IndexTestSuite) SetupSuite() {
}
func (suite *IndexTestSuite) SetupTest() {
var state state.State
state.Caches.Init()
suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
suite.db = testrig.NewTestDB(&state)
suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.filter = visibility.NewFilter(suite.db)
suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)

View File

@ -23,7 +23,6 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@ -39,15 +38,14 @@ func (suite *ManagerTestSuite) SetupSuite() {
}
func (suite *ManagerTestSuite) SetupTest() {
var state state.State
state.Caches.Init()
suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
suite.db = testrig.NewTestDB(&state)
suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.filter = visibility.NewFilter(suite.db)
suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)

View File

@ -25,7 +25,6 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@ -41,15 +40,14 @@ func (suite *PruneTestSuite) SetupSuite() {
}
func (suite *PruneTestSuite) SetupTest() {
var state state.State
state.Caches.Init()
suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
suite.db = testrig.NewTestDB(&state)
suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.filter = visibility.NewFilter(suite.db)
suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)

View File

@ -21,6 +21,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
@ -29,8 +30,9 @@ import (
type TimelineStandardTestSuite struct {
suite.Suite
db db.DB
state state.State
tc typeutils.TypeConverter
filter visibility.Filter
filter *visibility.Filter
testAccounts map[string]*gtsmodel.Account
testStatuses map[string]*gtsmodel.Status

View File

@ -470,6 +470,7 @@ const (
type TypeUtilsTestSuite struct {
suite.Suite
db db.DB
state state.State
testAccounts map[string]*gtsmodel.Account
testStatuses map[string]*gtsmodel.Status
testAttachments map[string]*gtsmodel.MediaAttachment
@ -482,13 +483,12 @@ type TypeUtilsTestSuite struct {
}
func (suite *TypeUtilsTestSuite) SetupSuite() {
var state state.State
state.Caches.Init()
suite.state.Caches.Init()
testrig.InitTestConfig()
testrig.InitTestLog()
suite.db = testrig.NewTestDB(&state)
suite.db = testrig.NewTestDB(&suite.state)
suite.testAccounts = testrig.NewTestAccounts()
suite.testStatuses = testrig.NewTestStatuses()
suite.testAttachments = testrig.NewTestAttachments()
@ -500,6 +500,7 @@ func (suite *TypeUtilsTestSuite) SetupSuite() {
}
func (suite *TypeUtilsTestSuite) SetupTest() {
suite.state.Caches.Init() // reset
testrig.StandardDBSetup(suite.db, nil)
}

View File

@ -59,7 +59,7 @@ func (c *converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode
// then adding the Source object to it...
// check pending follow requests aimed at this account
frc, err := c.db.CountFollowRequests(ctx, "", a.ID)
frc, err := c.db.CountAccountFollowRequests(ctx, a.ID)
if err != nil {
return nil, fmt.Errorf("error counting follow requests: %s", err)
}
@ -84,13 +84,13 @@ func (c *converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode
func (c *converter) AccountToAPIAccountPublic(ctx context.Context, a *gtsmodel.Account) (*apimodel.Account, error) {
// count followers
followersCount, err := c.db.CountFollows(ctx, "", a.ID)
followersCount, err := c.db.CountAccountFollowers(ctx, a.ID)
if err != nil {
return nil, fmt.Errorf("error counting followers: %s", err)
}
// count following
followingCount, err := c.db.CountFollows(ctx, a.ID, "")
followingCount, err := c.db.CountAccountFollows(ctx, a.ID)
if err != nil {
return nil, fmt.Errorf("error counting following: %s", err)
}

View File

@ -0,0 +1,151 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// AccountVisible will check if given account is visible to requester, accounting for requester with no auth (i.e is nil), suspensions, disabled local users and account blocks.
func (f *Filter) AccountVisible(ctx context.Context, requester *gtsmodel.Account, account *gtsmodel.Account) (bool, error) {
// By default we assume no auth.
requesterID := noauth
if requester != nil {
// Use provided account ID.
requesterID = requester.ID
}
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform visibility lookup.
visible, err := f.isAccountVisibleTo(ctx, requester, account)
if err != nil {
return nil, err
}
// Return visibility value.
return &cache.CachedVisibility{
ItemID: account.ID,
RequesterID: requesterID,
Type: cache.VisibilityTypeAccount,
Value: visible,
}, nil
}, "account", requesterID, account.ID)
if err != nil {
return false, err
}
return visibility.Value, nil
}
// isAccountVisibleTo will check if account is visible to requester. It is the "meat" of the logic to Filter{}.AccountVisible() which is called within cache loader callback.
func (f *Filter) isAccountVisibleTo(ctx context.Context, requester *gtsmodel.Account, account *gtsmodel.Account) (bool, error) {
// Check whether target account is visible to anyone.
visible, err := f.isAccountVisible(ctx, account)
if err != nil {
return false, fmt.Errorf("isAccountVisibleTo: error checking account %s visibility: %w", account.ID, err)
}
if !visible {
log.Trace(ctx, "target account is not visible to anyone")
return false, nil
}
if requester == nil {
// It seems stupid, but when un-authed all accounts are
// visible to allow for federation to work correctly.
return true, nil
}
// If requester is not visible, they cannot *see* either.
visible, err = f.isAccountVisible(ctx, requester)
if err != nil {
return false, fmt.Errorf("isAccountVisibleTo: error checking account %s visibility: %w", account.ID, err)
}
if !visible {
log.Trace(ctx, "requesting account cannot see other accounts")
return false, nil
}
// Check whether either blocks the other.
blocked, err := f.state.DB.IsEitherBlocked(ctx,
requester.ID,
account.ID,
)
if err != nil {
return false, fmt.Errorf("isAccountVisibleTo: error checking account blocks: %w", err)
}
if blocked {
log.Trace(ctx, "block exists between accounts")
return false, nil
}
return true, nil
}
// isAccountVisible will check if given account should be visible at all, e.g. it may not be if suspended or disabled.
func (f *Filter) isAccountVisible(ctx context.Context, account *gtsmodel.Account) (bool, error) {
if account.IsLocal() {
// This is a local account.
if account.Username == config.GetHost() {
// This is the instance actor account.
return true, nil
}
// Fetch the local user model for this account.
user, err := f.state.DB.GetUserByAccountID(ctx, account.ID)
if err != nil {
return false, err
}
// Make sure that user is active (i.e. not disabled, not approved etc).
if *user.Disabled || !*user.Approved || user.ConfirmedAt.IsZero() {
log.Trace(ctx, "local account not active")
return false, nil
}
} else {
// This is a remote account.
// Check whether remote account's domain is blocked.
blocked, err := f.state.DB.IsDomainBlocked(ctx, account.Domain)
if err != nil {
return false, err
}
if blocked {
log.Trace(ctx, "remote account domain blocked")
return false, nil
}
}
if !account.SuspendedAt.IsZero() {
log.Trace(ctx, "account suspended")
return false, nil
}
return true, nil
}

View File

@ -0,0 +1,62 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// StatusBoostable checks if given status is boostable by requester, checking boolean status visibility to requester and ultimately the AP status visibility setting.
func (f *Filter) StatusBoostable(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
if status.Visibility == gtsmodel.VisibilityDirect {
log.Trace(ctx, "direct statuses are not boostable")
return false, nil
}
// Check whether status is visible to requesting account.
visible, err := f.StatusVisible(ctx, requester, status)
if err != nil {
return false, err
}
if !visible {
log.Trace(ctx, "status not visible to requesting account")
return false, nil
}
if requester.ID == status.AccountID {
// Status author can always boost non-directs.
return true, nil
}
if status.Visibility == gtsmodel.VisibilityFollowersOnly ||
status.Visibility == gtsmodel.VisibilityMutualsOnly {
log.Trace(ctx, "unauthored %s status not boostable", status.Visibility)
return false, nil
}
if !*status.Boostable {
log.Trace(ctx, "status marked not boostable")
return false, nil
}
return true, nil
}

View File

@ -33,7 +33,7 @@ func (suite *StatusBoostableTestSuite) TestOwnPublicBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@ -44,7 +44,7 @@ func (suite *StatusBoostableTestSuite) TestOwnUnlockedBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@ -55,7 +55,7 @@ func (suite *StatusBoostableTestSuite) TestOwnMutualsOnlyNonInteractiveBoostable
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@ -66,7 +66,7 @@ func (suite *StatusBoostableTestSuite) TestOwnMutualsOnlyBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@ -77,7 +77,7 @@ func (suite *StatusBoostableTestSuite) TestOwnFollowersOnlyBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@ -88,7 +88,7 @@ func (suite *StatusBoostableTestSuite) TestOwnDirectNotBoostable() {
testAccount := suite.testAccounts["local_account_2"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(boostable)
@ -99,7 +99,7 @@ func (suite *StatusBoostableTestSuite) TestOtherPublicBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@ -110,7 +110,7 @@ func (suite *StatusBoostableTestSuite) TestOtherUnlistedBoostable() {
testAccount := suite.testAccounts["local_account_2"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@ -121,7 +121,7 @@ func (suite *StatusBoostableTestSuite) TestOtherFollowersOnlyNotBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(boostable)
@ -132,19 +132,19 @@ func (suite *StatusBoostableTestSuite) TestOtherDirectNotBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(boostable)
}
func (suite *StatusBoostableTestSuite) TestRemoteFollowersOnlyNotVisibleError() {
func (suite *StatusBoostableTestSuite) TestRemoteFollowersOnlyNotVisible() {
testStatus := suite.testStatuses["local_account_1_status_5"]
testAccount := suite.testAccounts["remote_account_1"]
ctx := context.Background()
boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
suite.Assert().Error(err)
boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(boostable)
}

View File

@ -18,46 +18,20 @@
package visibility
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
)
// Filter packages up a bunch of logic for checking whether given statuses or accounts are visible to a requester.
type Filter interface {
// StatusVisible returns true if targetStatus is visible to requestingAccount, based on the
// privacy settings of the status, and any blocks/mutes that might exist between the two accounts
// or account domains, and other relevant accounts mentioned in or replied to by the status.
StatusVisible(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
// noauth is a placeholder ID used in cache lookups
// when there is no authorized account ID to use.
const noauth = "noauth"
// StatusesVisible calls StatusVisible for each status in the statuses slice, and returns a slice of only
// statuses which are visible to the requestingAccount.
StatusesVisible(ctx context.Context, statuses []*gtsmodel.Status, requestingAccount *gtsmodel.Account) ([]*gtsmodel.Status, error)
// StatusHometimelineable returns true if targetStatus should be in the home timeline of the requesting account.
//
// This function will call StatusVisible internally, so it's not necessary to call it beforehand.
StatusHometimelineable(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
// StatusPublictimelineable returns true if targetStatus should be in the public timeline of the requesting account.
//
// This function will call StatusVisible internally, so it's not necessary to call it beforehand.
StatusPublictimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error)
// StatusBoostable returns true if targetStatus can be boosted by the requesting account.
//
// this function will call StatusVisible internally so it's not necessary to call it beforehand.
StatusBoostable(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
}
type filter struct {
db db.DB
// Filter packages up a bunch of logic for checking whether
// given statuses or accounts are visible to a requester.
type Filter struct {
state *state.State
}
// NewFilter returns a new Filter interface that will use the provided database.
func NewFilter(db db.DB) Filter {
return &filter{
db: db,
}
func NewFilter(state *state.State) *Filter {
return &Filter{state: state}
}

View File

@ -30,6 +30,7 @@ type FilterStandardTestSuite struct {
// standard suite interfaces
suite.Suite
db db.DB
state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@ -43,7 +44,7 @@ type FilterStandardTestSuite struct {
testMentions map[string]*gtsmodel.Mention
testFollows map[string]*gtsmodel.Follow
filter visibility.Filter
filter *visibility.Filter
}
func (suite *FilterStandardTestSuite) SetupSuite() {
@ -60,14 +61,13 @@ func (suite *FilterStandardTestSuite) SetupSuite() {
}
func (suite *FilterStandardTestSuite) SetupTest() {
var state state.State
state.Caches.Init()
suite.state.Caches.Init()
testrig.InitTestConfig()
testrig.InitTestLog()
suite.db = testrig.NewTestDB(&state)
suite.filter = visibility.NewFilter(suite.db)
suite.db = testrig.NewTestDB(&suite.state)
suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)
}

View File

@ -0,0 +1,165 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// StatusHomeTimelineable checks if given status should be included on owner's home timeline. Primarily relying on status visibility to owner and the AP visibility setting, but also taking into account thread replies etc.
func (f *Filter) StatusHomeTimelineable(ctx context.Context, owner *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
// By default we assume no auth.
requesterID := noauth
if owner != nil {
// Use provided account ID.
requesterID = owner.ID
}
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform timeline visibility lookup.
visible, err := f.isStatusHomeTimelineable(ctx, owner, status)
if err != nil {
return nil, err
}
// Return visibility value.
return &cache.CachedVisibility{
ItemID: status.ID,
RequesterID: requesterID,
Type: cache.VisibilityTypeHome,
Value: visible,
}, nil
}, "home", requesterID, status.ID)
if err != nil {
if err == cache.SentinelError {
// Filter-out our temporary
// race-condition error.
return false, nil
}
return false, err
}
return visibility.Value, nil
}
func (f *Filter) isStatusHomeTimelineable(ctx context.Context, owner *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
if status.CreatedAt.After(time.Now().Add(24 * time.Hour)) {
// Statuses made over 1 day in the future we don't show...
log.Warnf(ctx, "status >24hrs in the future: %+v", status)
return false, nil
}
// Check whether status is visible to timeline owner.
visible, err := f.StatusVisible(ctx, owner, status)
if err != nil {
return false, err
}
if !visible {
log.Trace(ctx, "status not visible to timeline owner")
return false, nil
}
if status.AccountID == owner.ID {
// Author can always see their status.
return true, nil
}
if status.MentionsAccount(owner.ID) {
// Can always see when you are mentioned.
return true, nil
}
var (
parent *gtsmodel.Status
included bool
oneAuthor bool
)
for parent = status; parent.InReplyToURI != ""; {
// Fetch next parent to lookup.
parentID := parent.InReplyToID
if parentID == "" {
log.Warnf(ctx, "status not yet deref'd: %s", parent.InReplyToURI)
return false, cache.SentinelError
}
// Get the next parent in the chain from DB.
parent, err = f.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
parentID,
)
if err != nil {
return false, fmt.Errorf("isStatusHomeTimelineable: error getting status parent %s: %w", parentID, err)
}
if (parent.AccountID == owner.ID) ||
parent.MentionsAccount(owner.ID) {
// Owner is in / mentioned in
// this status thread.
included = true
break
}
if oneAuthor {
// Check if this is a single-author status thread.
oneAuthor = (parent.AccountID == status.AccountID)
}
}
if parent != status && !included && !oneAuthor {
log.Trace(ctx, "ignoring visible reply to conversation thread excluding owner")
return false, nil
}
// At this point status is either a top-level status, a reply in a single
// author thread (e.g. "this is my weird-ass take and here is why 1/10 🧵"),
// or a thread mentioning / including timeline owner.
if status.Visibility == gtsmodel.VisibilityFollowersOnly ||
status.Visibility == gtsmodel.VisibilityMutualsOnly {
// Followers/mutuals only post that already passed the status
// visibility check, (i.e. we follow / mutuals with author).
return true, nil
}
// Ensure owner follows author of public/unlocked status.
follow, err := f.state.DB.IsFollowing(ctx,
owner.ID,
status.AccountID,
)
if err != nil {
return false, fmt.Errorf("isStatusHomeTimelineable: error checking follow %s->%s: %w", owner.ID, status.AccountID, err)
}
if !follow {
log.Trace(ctx, "ignoring visible status from unfollowed author")
return false, nil
}
return true, nil
}

View File

@ -25,86 +25,77 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type StatusStatusHometimelineableTestSuite struct {
type StatusStatusHomeTimelineableTestSuite struct {
FilterStandardTestSuite
}
func (suite *StatusStatusHometimelineableTestSuite) TestOwnStatusHometimelineable() {
func (suite *StatusStatusHomeTimelineableTestSuite) TestOwnStatusHomeTimelineable() {
testStatus := suite.testStatuses["local_account_1_status_1"]
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(timelineable)
}
func (suite *StatusStatusHometimelineableTestSuite) TestFollowingStatusHometimelineable() {
func (suite *StatusStatusHomeTimelineableTestSuite) TestFollowingStatusHomeTimelineable() {
testStatus := suite.testStatuses["local_account_2_status_1"]
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(timelineable)
}
func (suite *StatusStatusHometimelineableTestSuite) TestNotFollowingStatusHometimelineable() {
func (suite *StatusStatusHomeTimelineableTestSuite) TestNotFollowingStatusHomeTimelineable() {
testStatus := suite.testStatuses["remote_account_1_status_1"]
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(timelineable)
}
func (suite *StatusStatusHometimelineableTestSuite) TestStatusTooNewNotTimelineable() {
func (suite *StatusStatusHomeTimelineableTestSuite) TestStatusTooNewNotTimelineable() {
testStatus := &gtsmodel.Status{}
*testStatus = *suite.testStatuses["local_account_1_status_1"]
var err error
testStatus.ID, err = id.NewULIDFromTime(time.Now().Add(10 * time.Minute))
if err != nil {
suite.FailNow(err.Error())
}
testStatus.CreatedAt = time.Now().Add(25 * time.Hour)
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(timelineable)
}
func (suite *StatusStatusHometimelineableTestSuite) TestStatusNotTooNewTimelineable() {
func (suite *StatusStatusHomeTimelineableTestSuite) TestStatusNotTooNewTimelineable() {
testStatus := &gtsmodel.Status{}
*testStatus = *suite.testStatuses["local_account_1_status_1"]
var err error
testStatus.ID, err = id.NewULIDFromTime(time.Now().Add(4 * time.Minute))
if err != nil {
suite.FailNow(err.Error())
}
testStatus.CreatedAt = time.Now().Add(23 * time.Hour)
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(timelineable)
}
func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly() {
func (suite *StatusStatusHomeTimelineableTestSuite) TestChainReplyFollowersOnly() {
ctx := context.Background()
// This scenario makes sure that we don't timeline a status which is a followers-only
@ -112,9 +103,8 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly(
// timeline owner account doesn't follow.
//
// In other words, remote_account_1 posts a followers-only status, which local_account_1 replies to;
// THEN, local_account_1 replies to their own reply. We don't want this last status to appear
// in the timeline of local_account_2, even though they follow local_account_1, because they
// *don't* follow remote_account_1.
// THEN, local_account_1 replies to their own reply. None of these statuses should appear to
// local_account_2 since they don't follow the original parent.
//
// See: https://github.com/superseriousbusiness/gotosocial/issues/501
@ -152,7 +142,7 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly(
suite.FailNow(err.Error())
}
// this status should not be hometimelineable for local_account_2
originalStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, originalStatus, timelineOwnerAccount)
originalStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, originalStatus)
suite.NoError(err)
suite.False(originalStatusTimelineable)
@ -185,8 +175,8 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly(
if err := suite.db.PutStatus(ctx, firstReplyStatus); err != nil {
suite.FailNow(err.Error())
}
// this status should not be hometimelineable for local_account_2
firstReplyStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, firstReplyStatus, timelineOwnerAccount)
// this status should be hometimelineable for local_account_2
firstReplyStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, firstReplyStatus)
suite.NoError(err)
suite.False(firstReplyStatusTimelineable)
@ -221,12 +211,12 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly(
}
// this status should ALSO not be hometimelineable for local_account_2
secondReplyStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, secondReplyStatus, timelineOwnerAccount)
secondReplyStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, secondReplyStatus)
suite.NoError(err)
suite.False(secondReplyStatusTimelineable)
}
func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyPublicAndUnlocked() {
func (suite *StatusStatusHomeTimelineableTestSuite) TestChainReplyPublicAndUnlocked() {
ctx := context.Background()
// This scenario is exactly the same as the above test, but for a mix of unlocked + public posts
@ -265,7 +255,7 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyPublicAndUnloc
suite.FailNow(err.Error())
}
// this status should not be hometimelineable for local_account_2
originalStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, originalStatus, timelineOwnerAccount)
originalStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, originalStatus)
suite.NoError(err)
suite.False(originalStatusTimelineable)
@ -299,7 +289,7 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyPublicAndUnloc
suite.FailNow(err.Error())
}
// this status should not be hometimelineable for local_account_2
firstReplyStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, firstReplyStatus, timelineOwnerAccount)
firstReplyStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, firstReplyStatus)
suite.NoError(err)
suite.False(firstReplyStatusTimelineable)
@ -334,11 +324,11 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyPublicAndUnloc
}
// this status should ALSO not be hometimelineable for local_account_2
secondReplyStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, secondReplyStatus, timelineOwnerAccount)
secondReplyStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, secondReplyStatus)
suite.NoError(err)
suite.False(secondReplyStatusTimelineable)
}
func TestStatusHometimelineableTestSuite(t *testing.T) {
suite.Run(t, new(StatusStatusHometimelineableTestSuite))
func TestStatusHomeTimelineableTestSuite(t *testing.T) {
suite.Run(t, new(StatusStatusHomeTimelineableTestSuite))
}

View File

@ -0,0 +1,121 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// StatusHomeTimelineable checks if given status should be included on requester's public timeline. Primarily relying on status visibility to requester and the AP visibility setting, and ignoring conversation threads.
func (f *Filter) StatusPublicTimelineable(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
// By default we assume no auth.
requesterID := noauth
if requester != nil {
// Use provided account ID.
requesterID = requester.ID
}
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform timeline visibility lookup.
visible, err := f.isStatusPublicTimelineable(ctx, requester, status)
if err != nil {
return nil, err
}
// Return visibility value.
return &cache.CachedVisibility{
ItemID: status.ID,
RequesterID: requesterID,
Type: cache.VisibilityTypePublic,
Value: visible,
}, nil
}, "public", requesterID, status.ID)
if err != nil {
if err == cache.SentinelError {
// Filter-out our temporary
// race-condition error.
return false, nil
}
return false, err
}
return visibility.Value, nil
}
func (f *Filter) isStatusPublicTimelineable(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
if status.CreatedAt.After(time.Now().Add(24 * time.Hour)) {
// Statuses made over 1 day in the future we don't show...
log.Warnf(ctx, "status >24hrs in the future: %+v", status)
return false, nil
}
// Don't show boosts on timeline.
if status.BoostOfID != "" {
return false, nil
}
// Check whether status is visible to requesting account.
visible, err := f.StatusVisible(ctx, requester, status)
if err != nil {
return false, err
}
if !visible {
log.Trace(ctx, "status not visible to timeline requester")
return false, nil
}
for parent := status; parent.InReplyToURI != ""; {
// Fetch next parent to lookup.
parentID := parent.InReplyToID
if parentID == "" {
log.Warnf(ctx, "status not yet deref'd: %s", parent.InReplyToURI)
return false, cache.SentinelError
}
// Get the next parent in the chain from DB.
parent, err = f.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
parentID,
)
if err != nil {
return false, fmt.Errorf("isStatusHomeTimelineable: error getting status parent %s: %w", parentID, err)
}
if parent.AccountID != status.AccountID {
// This is not a single author reply-chain-thread,
// instead is an actualy conversation. Don't timeline.
log.Trace(ctx, "ignoring multi-author reply-chain")
return false, nil
}
}
// This is either a visible status in a
// single-author thread, or a visible top
// level status. Show on public timeline.
return true, nil
}

View File

@ -1,230 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// relevantAccounts denotes accounts that are replied to, boosted by, or mentioned in a status.
type relevantAccounts struct {
// Who wrote the status
Account *gtsmodel.Account
// Who is the status replying to
InReplyToAccount *gtsmodel.Account
// Which accounts are mentioned (tagged) in the status
MentionedAccounts []*gtsmodel.Account
// Who authed the boosted status
BoostedAccount *gtsmodel.Account
// If the boosted status replies to another account, who does it reply to?
BoostedInReplyToAccount *gtsmodel.Account
// Who is mentioned (tagged) in the boosted status
BoostedMentionedAccounts []*gtsmodel.Account
}
func (f *filter) relevantAccounts(ctx context.Context, status *gtsmodel.Status, getBoosted bool) (*relevantAccounts, error) {
relAccts := &relevantAccounts{
MentionedAccounts: []*gtsmodel.Account{},
BoostedMentionedAccounts: []*gtsmodel.Account{},
}
/*
Here's what we need to try and extract from the status:
// 1. Who wrote the status
Account *gtsmodel.Account
// 2. Who is the status replying to
InReplyToAccount *gtsmodel.Account
// 3. Which accounts are mentioned (tagged) in the status
MentionedAccounts []*gtsmodel.Account
if getBoosted:
// 4. Who wrote the boosted status
BoostedAccount *gtsmodel.Account
// 5. If the boosted status replies to another account, who does it reply to?
BoostedInReplyToAccount *gtsmodel.Account
// 6. Who is mentioned (tagged) in the boosted status
BoostedMentionedAccounts []*gtsmodel.Account
*/
// 1. Account.
// Account might be set on the status already
if status.Account != nil {
// it was set
relAccts.Account = status.Account
} else {
// it wasn't set, so get it from the db
account, err := f.db.GetAccountByID(ctx, status.AccountID)
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting account with id %s: %s", status.AccountID, err)
}
// set it on the status in case we need it further along
status.Account = account
// set it on relevant accounts
relAccts.Account = account
}
// 2. InReplyToAccount
// only get this if InReplyToAccountID is set
if status.InReplyToAccountID != "" {
// InReplyToAccount might be set on the status already
if status.InReplyToAccount != nil {
// it was set
relAccts.InReplyToAccount = status.InReplyToAccount
} else {
// it wasn't set, so get it from the db
inReplyToAccount, err := f.db.GetAccountByID(ctx, status.InReplyToAccountID)
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting inReplyToAccount with id %s: %s", status.InReplyToAccountID, err)
}
// set it on the status in case we need it further along
status.InReplyToAccount = inReplyToAccount
// set it on relevant accounts
relAccts.InReplyToAccount = inReplyToAccount
}
}
// 3. MentionedAccounts
// First check if status.Mentions is populated with all mentions that correspond to status.MentionIDs
for _, mID := range status.MentionIDs {
if mID == "" {
continue
}
if !idIn(mID, status.Mentions) {
// mention with ID isn't in status.Mentions
mention, err := f.db.GetMention(ctx, mID)
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting mention with id %s: %s", mID, err)
}
if mention == nil {
return nil, fmt.Errorf("relevantAccounts: mention with id %s was nil", mID)
}
status.Mentions = append(status.Mentions, mention)
}
}
// now filter mentions to make sure we only have mentions with a corresponding ID
nm := []*gtsmodel.Mention{}
for _, m := range status.Mentions {
if m == nil {
continue
}
if mentionIn(m, status.MentionIDs) {
nm = append(nm, m)
relAccts.MentionedAccounts = append(relAccts.MentionedAccounts, m.TargetAccount)
}
}
status.Mentions = nm
if len(status.Mentions) != len(status.MentionIDs) {
return nil, errors.New("relevantAccounts: mentions length did not correspond with mentionIDs length")
}
// if getBoosted is set, we should check the same properties on the boosted account as well
if getBoosted {
// 4, 5, 6. Boosted status items
// get the boosted status if it's not set on the status already
if status.BoostOfID != "" && status.BoostOf == nil {
boostedStatus, err := f.db.GetStatusByID(ctx, status.BoostOfID)
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting boosted status with id %s: %s", status.BoostOfID, err)
}
status.BoostOf = boostedStatus
}
if status.BoostOf != nil {
// return relevant accounts for the boosted status
boostedRelAccts, err := f.relevantAccounts(ctx, status.BoostOf, false) // false because we don't want to recurse
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting relevant accounts of boosted status %s: %s", status.BoostOf.ID, err)
}
relAccts.BoostedAccount = boostedRelAccts.Account
relAccts.BoostedInReplyToAccount = boostedRelAccts.InReplyToAccount
relAccts.BoostedMentionedAccounts = boostedRelAccts.MentionedAccounts
}
}
return relAccts, nil
}
// domainBlockedRelevant checks through all relevant accounts attached to a status
// to make sure none of them are domain blocked by this instance.
func (f *filter) domainBlockedRelevant(ctx context.Context, r *relevantAccounts) (bool, error) {
domains := []string{}
if r.Account != nil {
domains = append(domains, r.Account.Domain)
}
if r.InReplyToAccount != nil {
domains = append(domains, r.InReplyToAccount.Domain)
}
for _, a := range r.MentionedAccounts {
if a != nil {
domains = append(domains, a.Domain)
}
}
if r.BoostedAccount != nil {
domains = append(domains, r.BoostedAccount.Domain)
}
if r.BoostedInReplyToAccount != nil {
domains = append(domains, r.BoostedInReplyToAccount.Domain)
}
for _, a := range r.BoostedMentionedAccounts {
if a != nil {
domains = append(domains, a.Domain)
}
}
return f.db.AreDomainsBlocked(ctx, domains)
}
func idIn(id string, mentions []*gtsmodel.Mention) bool {
for _, m := range mentions {
if m == nil {
continue
}
if m.ID == id {
return true
}
}
return false
}
func mentionIn(mention *gtsmodel.Mention, ids []string) bool {
if mention == nil {
return false
}
for _, i := range ids {
if mention.ID == i {
return true
}
}
return false
}

View File

@ -0,0 +1,217 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// StatusesVisible calls StatusVisible for each status in the statuses slice, and returns a slice of only statuses which are visible to the requester.
func (f *Filter) StatusesVisible(ctx context.Context, requester *gtsmodel.Account, statuses []*gtsmodel.Status) ([]*gtsmodel.Status, error) {
// Preallocate slice of maximum possible length.
filtered := make([]*gtsmodel.Status, 0, len(statuses))
for _, status := range statuses {
// Check whether status is visible to requester.
visible, err := f.StatusVisible(ctx, requester, status)
if err != nil {
return nil, err
}
if visible {
// Add filtered status to ret slice.
filtered = append(filtered, status)
}
}
return filtered, nil
}
// StatusVisible will check if given status is visible to requester, accounting for requester with no auth (i.e is nil), suspensions, disabled local users, account blocks and status privacy.
func (f *Filter) StatusVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
// By default we assume no auth.
requesterID := noauth
if requester != nil {
// Use provided account ID.
requesterID = requester.ID
}
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform visibility lookup.
visible, err := f.isStatusVisible(ctx, requester, status)
if err != nil {
return nil, err
}
// Return visibility value.
return &cache.CachedVisibility{
ItemID: status.ID,
RequesterID: requesterID,
Type: cache.VisibilityTypeStatus,
Value: visible,
}, nil
}, "status", requesterID, status.ID)
if err != nil {
return false, err
}
return visibility.Value, nil
}
// isStatusVisible will check if status is visible to requester. It is the "meat" of the logic to Filter{}.StatusVisible() which is called within cache loader callback.
func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
// Ensure that status is fully populated for further processing.
if err := f.state.DB.PopulateStatus(ctx, status); err != nil {
return false, err
}
// Check whether status accounts are visible to the requester.
visible, err := f.areStatusAccountsVisible(ctx, requester, status)
if err != nil {
return false, fmt.Errorf("isStatusVisible: error checking status %s account visibility: %w", status.ID, err)
} else if !visible {
return false, nil
}
if status.Visibility == gtsmodel.VisibilityPublic {
// This status will be visible to all.
return true, nil
}
if requester == nil {
// This request is WITHOUT auth, and status is NOT public.
log.Trace(ctx, "unauthorized request to non-public status")
return false, nil
}
if status.Visibility == gtsmodel.VisibilityUnlocked {
// This status is visible to all auth'd accounts.
return true, nil
}
if requester.ID == status.AccountID {
// Author can always see their own status.
return true, nil
}
if status.MentionsAccount(requester.ID) {
// Status mentions the requesting account.
return true, nil
}
if status.BoostOf != nil {
if !status.BoostOf.MentionsPopulated() {
// Boosted status needs its mentions populating, fetch these from database.
status.BoostOf.Mentions, err = f.state.DB.GetMentions(ctx, status.BoostOf.MentionIDs)
if err != nil {
return false, fmt.Errorf("isStatusVisible: error populating boosted status %s mentions: %w", status.BoostOfID, err)
}
}
if status.BoostOf.MentionsAccount(requester.ID) {
// Boosted status mentions the requesting account.
return true, nil
}
}
switch status.Visibility {
case gtsmodel.VisibilityFollowersOnly:
// Check requester follows status author.
follows, err := f.state.DB.IsFollowing(ctx,
requester.ID,
status.AccountID,
)
if err != nil {
return false, fmt.Errorf("isStatusVisible: error checking follow %s->%s: %w", requester.ID, status.AccountID, err)
}
if !follows {
log.Trace(ctx, "follow-only status not visible to requester")
return false, nil
}
return true, nil
case gtsmodel.VisibilityMutualsOnly:
// Check mutual following between requester and author.
mutuals, err := f.state.DB.IsMutualFollowing(ctx,
requester.ID,
status.AccountID,
)
if err != nil {
return false, fmt.Errorf("isStatusVisible: error checking mutual follow %s<->%s: %w", requester.ID, status.AccountID, err)
}
if !mutuals {
log.Trace(ctx, "mutual-only status not visible to requester")
return false, nil
}
return true, nil
case gtsmodel.VisibilityDirect:
log.Trace(ctx, "direct status not visible to requester")
return false, nil
default:
log.Warnf(ctx, "unexpected status visibility %s for %s", status.Visibility, status.URI)
return false, nil
}
}
// areStatusAccountsVisible calls Filter{}.AccountVisible() on status author and the status boost-of (if set) author, returning visibility of status (and boost-of) to requester.
func (f *Filter) areStatusAccountsVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
// Check whether status author's account is visible to requester.
visible, err := f.AccountVisible(ctx, requester, status.Account)
if err != nil {
return false, err
}
if !visible {
log.Trace(ctx, "status author not visible to requester")
return false, nil
}
if status.BoostOfID != "" {
// This is a boosted status.
if status.AccountID == status.BoostOfAccountID {
// Some clout-chaser boosted their own status, tch.
return true, nil
}
// Check whether boosted status author's account is visible to requester.
visible, err := f.AccountVisible(ctx, requester, status.BoostOfAccount)
if err != nil {
return false, err
}
if !visible {
log.Trace(ctx, "boosted status author not visible to requester")
return false, nil
}
}
return true, nil
}

View File

@ -34,7 +34,7 @@ func (suite *StatusVisibleTestSuite) TestOwnStatusVisible() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(visible)
@ -48,7 +48,7 @@ func (suite *StatusVisibleTestSuite) TestOwnDMVisible() {
suite.NoError(err)
testAccount := suite.testAccounts["local_account_2"]
visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(visible)
@ -62,7 +62,7 @@ func (suite *StatusVisibleTestSuite) TestDMVisibleToTarget() {
suite.NoError(err)
testAccount := suite.testAccounts["local_account_1"]
visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(visible)
@ -76,7 +76,7 @@ func (suite *StatusVisibleTestSuite) TestDMNotVisibleIfNotMentioned() {
suite.NoError(err)
testAccount := suite.testAccounts["admin_account"]
visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(visible)
@ -92,7 +92,7 @@ func (suite *StatusVisibleTestSuite) TestStatusNotVisibleIfNotMutuals() {
suite.NoError(err)
testAccount := suite.testAccounts["local_account_2"]
visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(visible)
@ -108,12 +108,54 @@ func (suite *StatusVisibleTestSuite) TestStatusNotVisibleIfNotFollowing() {
suite.NoError(err)
testAccount := suite.testAccounts["admin_account"]
visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(visible)
}
func (suite *StatusVisibleTestSuite) TestStatusNotVisibleIfNotMutualsCached() {
ctx := context.Background()
testStatusID := suite.testStatuses["local_account_1_status_4"].ID
testStatus, err := suite.db.GetStatusByID(ctx, testStatusID)
suite.NoError(err)
testAccount := suite.testAccounts["local_account_2"]
// Perform a status visibility check while mutuals, this shsould be true.
visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(visible)
err = suite.db.DeleteFollowByID(ctx, suite.testFollows["local_account_2_local_account_1"].ID)
suite.NoError(err)
// Perform a status visibility check after unfollow, this should be false.
visible, err = suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(visible)
}
func (suite *StatusVisibleTestSuite) TestStatusNotVisibleIfNotFollowingCached() {
ctx := context.Background()
testStatusID := suite.testStatuses["local_account_1_status_5"].ID
testStatus, err := suite.db.GetStatusByID(ctx, testStatusID)
suite.NoError(err)
testAccount := suite.testAccounts["admin_account"]
// Perform a status visibility check while following, this shsould be true.
visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(visible)
err = suite.db.DeleteFollowByID(ctx, suite.testFollows["admin_account_local_account_1"].ID)
suite.NoError(err)
// Perform a status visibility check after unfollow, this should be false.
visible, err = suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(visible)
}
func TestStatusVisibleTestSuite(t *testing.T) {
suite.Run(t, new(StatusVisibleTestSuite))
}

View File

@ -1,60 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
func (f *filter) StatusBoostable(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) {
// if the status isn't visible, it certainly isn't boostable
visible, err := f.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return false, fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)
}
if !visible {
return false, errors.New("status is not visible")
}
// direct messages are never boostable, even if they're visible
if targetStatus.Visibility == gtsmodel.VisibilityDirect {
log.Trace(ctx, "status is not boostable because it is a DM")
return false, nil
}
// the original account should always be able to boost its own non-DM statuses
if requestingAccount.ID == targetStatus.Account.ID {
log.Trace(ctx, "status is boostable because author is booster")
return true, nil
}
// if status is followers-only and not the author's, it is not boostable
if targetStatus.Visibility == gtsmodel.VisibilityFollowersOnly {
log.Trace(ctx, "status not boostable because it is followers-only")
return false, nil
}
// otherwise, status is as boostable as it says it is
log.Trace(ctx, "defaulting to status.boostable value")
return *targetStatus.Boostable, nil
}

View File

@ -1,126 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"fmt"
"time"
"codeberg.org/gruf/go-kv"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
func (f *filter) StatusHometimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) {
l := log.WithContext(ctx).
WithFields(kv.Fields{{"statusID", targetStatus.ID}}...)
// don't timeline statuses more than 5 min in the future
maxID, err := id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
if err != nil {
return false, err
}
if targetStatus.ID > maxID {
l.Debug("status not hometimelineable because it's from more than 5 minutes in the future")
return false, nil
}
// status owner should always be able to see their own status in their timeline so we can return early if this is the case
if targetStatus.AccountID == timelineOwnerAccount.ID {
return true, nil
}
v, err := f.StatusVisible(ctx, targetStatus, timelineOwnerAccount)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error checking visibility of status with id %s: %s", targetStatus.ID, err)
}
if !v {
l.Debug("status is not hometimelineable because it's not visible to the requester")
return false, nil
}
for _, m := range targetStatus.Mentions {
if m.TargetAccountID == timelineOwnerAccount.ID {
// if we're mentioned we should be able to see the post
return true, nil
}
}
// check we follow the originator of the status
if targetStatus.Account == nil {
tsa, err := f.db.GetAccountByID(ctx, targetStatus.AccountID)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error getting status author account with id %s: %s", targetStatus.AccountID, err)
}
targetStatus.Account = tsa
}
following, err := f.db.IsFollowing(ctx, timelineOwnerAccount, targetStatus.Account)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error checking if %s follows %s: %s", timelineOwnerAccount.ID, targetStatus.AccountID, err)
}
if !following {
return false, nil
}
// Don't timeline a status whose parent hasn't been dereferenced yet or can't be dereferenced.
// If we have the reply to URI but don't have an ID for the replied-to account or the replied-to status in our database, we haven't dereferenced it yet.
if targetStatus.InReplyToURI != "" && (targetStatus.InReplyToID == "" || targetStatus.InReplyToAccountID == "") {
return false, nil
}
// if a status replies to an ID we know in the database, we need to check that parent status too
if targetStatus.InReplyToID != "" {
// pin the reply to status on to this status if it hasn't been done already
if targetStatus.InReplyTo == nil {
rs, err := f.db.GetStatusByID(ctx, targetStatus.InReplyToID)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error getting replied to status with id %s: %s", targetStatus.InReplyToID, err)
}
targetStatus.InReplyTo = rs
}
// pin the reply to account on to this status if it hasn't been done already
if targetStatus.InReplyToAccount == nil {
ra, err := f.db.GetAccountByID(ctx, targetStatus.InReplyToAccountID)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error getting replied to account with id %s: %s", targetStatus.InReplyToAccountID, err)
}
targetStatus.InReplyToAccount = ra
}
// if it's a reply to the timelineOwnerAccount, we don't need to check if the timelineOwnerAccount follows itself, just return true, they can see it
if targetStatus.InReplyToAccountID == timelineOwnerAccount.ID {
return true, nil
}
// make sure the parent status is also home timelineable, otherwise we shouldn't timeline this one either
parentStatusTimelineable, err := f.StatusHometimelineable(ctx, targetStatus.InReplyTo, timelineOwnerAccount)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error checking timelineability of parent status %s of status %s: %s", targetStatus.InReplyToID, targetStatus.ID, err)
}
if !parentStatusTimelineable {
return false, nil
}
}
return true, nil
}

View File

@ -1,72 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"fmt"
"time"
"codeberg.org/gruf/go-kv"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
func (f *filter) StatusPublictimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) {
l := log.WithContext(ctx).
WithFields(kv.Fields{{"statusID", targetStatus.ID}}...)
// don't timeline statuses more than 5 min in the future
maxID, err := id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
if err != nil {
return false, err
}
if targetStatus.ID > maxID {
l.Debug("status not hometimelineable because it's from more than 5 minutes in the future")
return false, nil
}
// Don't timeline boosted statuses
if targetStatus.BoostOfID != "" {
return false, nil
}
// Don't timeline a reply
if targetStatus.InReplyToURI != "" || targetStatus.InReplyToID != "" || targetStatus.InReplyToAccountID != "" {
return false, nil
}
// status owner should always be able to see their own status in their timeline so we can return early if this is the case
if timelineOwnerAccount != nil && targetStatus.AccountID == timelineOwnerAccount.ID {
return true, nil
}
v, err := f.StatusVisible(ctx, targetStatus, timelineOwnerAccount)
if err != nil {
return false, fmt.Errorf("StatusPublictimelineable: error checking visibility of status with id %s: %s", targetStatus.ID, err)
}
if !v {
l.Debug("status is not publicTimelineable because it's not visible to the requester")
return false, nil
}
return true, nil
}

View File

@ -1,252 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package visibility
import (
"context"
"fmt"
"codeberg.org/gruf/go-kv"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) {
l := log.WithContext(ctx).
WithFields(kv.Fields{{"statusID", targetStatus.ID}}...)
// Fetch any relevant accounts for the target status
const getBoosted = true
relevantAccounts, err := f.relevantAccounts(ctx, targetStatus, getBoosted)
if err != nil {
l.Debugf("error pulling relevant accounts for status %s: %s", targetStatus.ID, err)
return false, fmt.Errorf("StatusVisible: error pulling relevant accounts for status %s: %s", targetStatus.ID, err)
}
// Check we have determined a target account
targetAccount := relevantAccounts.Account
if targetAccount == nil {
l.Trace("target account is not set")
return false, nil
}
// Check for domain blocks among relevant accounts
domainBlocked, err := f.domainBlockedRelevant(ctx, relevantAccounts)
if err != nil {
l.Debugf("error checking domain block: %s", err)
return false, fmt.Errorf("error checking domain block: %s", err)
} else if domainBlocked {
return false, nil
}
// if target account is suspended then don't show the status
if !targetAccount.SuspendedAt.IsZero() {
l.Trace("target account suspended at is not zero")
return false, nil
}
// if the target user doesn't exist (anymore) then the status also shouldn't be visible
// note: we only do this for local users
if targetAccount.Domain == "" {
targetUser, err := f.db.GetUserByAccountID(ctx, targetAccount.ID)
if err != nil {
l.Debug("target user could not be selected")
if err == db.ErrNoEntries {
return false, nil
}
return false, fmt.Errorf("StatusVisible: db error selecting user for local target account %s: %s", targetAccount.ID, err)
}
// if target user is disabled, not yet approved, or not confirmed then don't show the status
// (although in the latter two cases it's unlikely they posted a status yet anyway, but you never know!)
if *targetUser.Disabled || !*targetUser.Approved || targetUser.ConfirmedAt.IsZero() {
l.Trace("target user is disabled, not approved, or not confirmed")
return false, nil
}
}
// If requesting account is nil, that means whoever requested the status didn't auth, or their auth failed.
// In this case, we can still serve the status if it's public, otherwise we definitely shouldn't.
if requestingAccount == nil {
if targetStatus.Visibility == gtsmodel.VisibilityPublic {
return true, nil
}
l.Trace("requesting account is nil but the target status isn't public")
return false, nil
}
// if the requesting user doesn't exist (anymore) then the status also shouldn't be visible
// note: we only do this for local users
if requestingAccount.Domain == "" {
requestingUser, err := f.db.GetUserByAccountID(ctx, requestingAccount.ID)
if err != nil {
// if the requesting account is local but doesn't have a corresponding user in the db this is a problem
l.Debug("requesting user could not be selected")
if err == db.ErrNoEntries {
return false, nil
}
return false, fmt.Errorf("StatusVisible: db error selecting user for local requesting account %s: %s", requestingAccount.ID, err)
}
// okay, user exists, so make sure it has full privileges/is confirmed/approved
if *requestingUser.Disabled || !*requestingUser.Approved || requestingUser.ConfirmedAt.IsZero() {
l.Trace("requesting account is local but corresponding user is either disabled, not approved, or not confirmed")
return false, nil
}
}
// if requesting account is suspended then don't show the status -- although they probably shouldn't have gotten
// this far (ie., been authed) in the first place: this is just for safety.
if !requestingAccount.SuspendedAt.IsZero() {
l.Trace("requesting account is suspended")
return false, nil
}
// if the target status belongs to the requesting account, they should always be able to view it at this point
if targetStatus.AccountID == requestingAccount.ID {
return true, nil
}
// At this point we have a populated targetAccount, targetStatus, and requestingAccount, so we can check for blocks and whathaveyou
// First check if a block exists directly between the target account (which authored the status) and the requesting account.
if blocked, err := f.db.IsBlocked(ctx, targetAccount.ID, requestingAccount.ID, true); err != nil {
l.Debugf("something went wrong figuring out if the accounts have a block: %s", err)
return false, err
} else if blocked {
// don't allow the status to be viewed if a block exists in *either* direction between these two accounts, no creepy stalking please
l.Trace("a block exists between requesting account and target account")
return false, nil
}
// If not in reply to the requesting account, check if inReplyToAccount is blocked
if relevantAccounts.InReplyToAccount != nil && relevantAccounts.InReplyToAccount.ID != requestingAccount.ID {
if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.InReplyToAccount.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and reply to account")
return false, nil
}
}
// status boosts accounts id
if relevantAccounts.BoostedAccount != nil {
if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.BoostedAccount.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and boosted account")
return false, nil
}
}
// status boosts a reply to account id
if relevantAccounts.BoostedInReplyToAccount != nil {
if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.BoostedInReplyToAccount.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and boosted reply to account")
return false, nil
}
}
// boost mentions accounts
for _, a := range relevantAccounts.BoostedMentionedAccounts {
if a == nil {
continue
}
if blocked, err := f.db.IsBlocked(ctx, a.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and a boosted mentioned account")
return false, nil
}
}
// Iterate mentions to check for blocks or requester mentions
isMentioned, blockAmongMentions := false, false
for _, a := range relevantAccounts.MentionedAccounts {
if a == nil {
continue
}
if blocked, err := f.db.IsBlocked(ctx, a.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
blockAmongMentions = true
break
}
if a.ID == requestingAccount.ID {
isMentioned = true
}
}
if blockAmongMentions {
l.Trace("a block exists between requesting account and a mentioned account")
return false, nil
} else if isMentioned {
// Requester mentioned, should always be visible
return true, nil
}
// at this point we know neither account blocks the other, or another account mentioned or otherwise referred to in the status
// that means it's now just a matter of checking the visibility settings of the status itself
switch targetStatus.Visibility {
case gtsmodel.VisibilityPublic, gtsmodel.VisibilityUnlocked:
// no problem here
case gtsmodel.VisibilityFollowersOnly:
// Followers-only post, check for a one-way follow to target
follows, err := f.db.IsFollowing(ctx, requestingAccount, targetAccount)
if err != nil {
return false, err
}
if !follows {
l.Trace("requested status is followers only but requesting account is not a follower")
return false, nil
}
case gtsmodel.VisibilityMutualsOnly:
// Mutuals-only post, check for a mutual follow
mutuals, err := f.db.IsMutualFollowing(ctx, requestingAccount, targetAccount)
if err != nil {
return false, err
}
if !mutuals {
l.Trace("requested status is mutuals only but accounts aren't mufos")
return false, nil
}
case gtsmodel.VisibilityDirect:
l.Trace("requesting account requests a direct status it's not mentioned in")
return false, nil // it's not mentioned -_-
}
// If we reached here, all is okay
return true, nil
}
func (f *filter) StatusesVisible(ctx context.Context, statuses []*gtsmodel.Status, requestingAccount *gtsmodel.Account) ([]*gtsmodel.Status, error) {
filtered := []*gtsmodel.Status{}
for _, s := range statuses {
visible, err := f.StatusVisible(ctx, s, requestingAccount)
if err != nil {
return nil, err
}
if visible {
filtered = append(filtered, s)
}
}
return filtered, nil
}

View File

@ -20,43 +20,55 @@ EXPECT=$(cat <<"EOF"
"account-max-size": 99,
"account-sweep-freq": 1000000000,
"account-ttl": 10800000000000,
"block-max-size": 100,
"block-sweep-freq": 30000000000,
"block-ttl": 300000000000,
"domain-block-max-size": 1000,
"block-max-size": 1000,
"block-sweep-freq": 60000000000,
"block-ttl": 1800000000000,
"domain-block-max-size": 2000,
"domain-block-sweep-freq": 60000000000,
"domain-block-ttl": 86400000000000,
"emoji-category-max-size": 100,
"emoji-category-sweep-freq": 30000000000,
"emoji-category-ttl": 300000000000,
"emoji-max-size": 500,
"emoji-sweep-freq": 30000000000,
"emoji-ttl": 300000000000,
"media-max-size": 500,
"media-sweep-freq": 30000000000,
"media-ttl": 300000000000,
"mention-max-size": 500,
"mention-sweep-freq": 30000000000,
"mention-ttl": 300000000000,
"notification-max-size": 500,
"notification-sweep-freq": 30000000000,
"notification-ttl": 300000000000,
"emoji-category-sweep-freq": 60000000000,
"emoji-category-ttl": 1800000000000,
"emoji-max-size": 2000,
"emoji-sweep-freq": 60000000000,
"emoji-ttl": 1800000000000,
"follow-max-size": 2000,
"follow-request-max-size": 2000,
"follow-request-sweep-freq": 60000000000,
"follow-request-ttl": 1800000000000,
"follow-sweep-freq": 60000000000,
"follow-ttl": 1800000000000,
"media-max-size": 1000,
"media-sweep-freq": 60000000000,
"media-ttl": 1800000000000,
"mention-max-size": 2000,
"mention-sweep-freq": 60000000000,
"mention-ttl": 1800000000000,
"notification-max-size": 1000,
"notification-sweep-freq": 60000000000,
"notification-ttl": 1800000000000,
"report-max-size": 100,
"report-sweep-freq": 30000000000,
"report-ttl": 300000000000,
"status-max-size": 500,
"status-sweep-freq": 30000000000,
"status-ttl": 300000000000,
"tombstone-max-size": 100,
"tombstone-sweep-freq": 30000000000,
"tombstone-ttl": 300000000000,
"user-max-size": 100,
"user-sweep-freq": 30000000000,
"user-ttl": 300000000000,
"report-sweep-freq": 60000000000,
"report-ttl": 1800000000000,
"status-fave-max-size": 2000,
"status-fave-sweep-freq": 60000000000,
"status-fave-ttl": 1800000000000,
"status-max-size": 2000,
"status-sweep-freq": 60000000000,
"status-ttl": 1800000000000,
"tombstone-max-size": 500,
"tombstone-sweep-freq": 60000000000,
"tombstone-ttl": 1800000000000,
"user-max-size": 500,
"user-sweep-freq": 60000000000,
"user-ttl": 1800000000000,
"webfinger-max-size": 250,
"webfinger-sweep-freq": 900000000000,
"webfinger-ttl": 86400000000000
}
},
"visibility-max-size": 2000,
"visibility-sweep-freq": 60000000000,
"visibility-ttl": 1800000000000
},
"config-path": "internal/config/testdata/test.yaml",
"db-address": ":memory:",