mirror of
1
Fork 0

[performance] overhaul struct (+ result) caching library for simplicity, performance and multiple-result lookups (#2535)

* rewrite cache library as codeberg.org/gruf/go-structr, implement in gotosocial

* use actual go-structr release version (not just commit hash)

* revert go toolchain changes (damn you go for auto changing this)

* fix go mod woes

* ensure %w is used in calls to errs.Appendf()

* fix error checking

* fix possible panic

* remove unnecessary start/stop functions, move to main Cache{} struct, add note regarding which caches require start/stop

* fix copy-paste artifact... 😇

* fix all comment copy-paste artifacts

* remove dropID() function, now we can just use slices.DeleteFunc()

* use util.Deduplicate() instead of collate(), move collate to util

* move orderByIDs() to util package and "generify"

* add a util.DeleteIf() function, use this to delete entries on failed population

* use slices.DeleteFunc() instead of util.DeleteIf() (i had the logic mixed up in my head somehow lol)

* add note about how collate differs from deduplicate
This commit is contained in:
kim 2024-01-19 12:57:29 +00:00 committed by GitHub
parent 67e11a1a61
commit 7ec1e1332e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
66 changed files with 4038 additions and 2711 deletions

View File

@ -133,7 +133,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
time.Time{}, // start time.Time{}, // start
time.Minute, // freq time.Minute, // freq
func(context.Context, time.Time) { func(context.Context, time.Time) {
state.Caches.Sweep(80) state.Caches.Sweep(60)
}, },
) )

1
go.mod
View File

@ -18,6 +18,7 @@ require (
codeberg.org/gruf/go-runners v1.6.2 codeberg.org/gruf/go-runners v1.6.2
codeberg.org/gruf/go-sched v1.2.3 codeberg.org/gruf/go-sched v1.2.3
codeberg.org/gruf/go-store/v2 v2.2.4 codeberg.org/gruf/go-store/v2 v2.2.4
codeberg.org/gruf/go-structr v0.1.1
codeberg.org/superseriousbusiness/exif-terminator v0.7.0 codeberg.org/superseriousbusiness/exif-terminator v0.7.0
github.com/DmitriyVTitov/size v1.5.0 github.com/DmitriyVTitov/size v1.5.0
github.com/KimMachineGun/automemlimit v0.4.0 github.com/KimMachineGun/automemlimit v0.4.0

2
go.sum
View File

@ -70,6 +70,8 @@ codeberg.org/gruf/go-sched v1.2.3 h1:H5ViDxxzOBR3uIyGBCf0eH8b1L8wMybOXcdtUUTXZHk
codeberg.org/gruf/go-sched v1.2.3/go.mod h1:vT9uB6KWFIIwnG9vcPY2a0alYNoqdL1mSzRM8I+PK7A= codeberg.org/gruf/go-sched v1.2.3/go.mod h1:vT9uB6KWFIIwnG9vcPY2a0alYNoqdL1mSzRM8I+PK7A=
codeberg.org/gruf/go-store/v2 v2.2.4 h1:8HO1Jh2gg7boQKA3hsDAIXd9zwieu5uXwDXEcTOD9js= codeberg.org/gruf/go-store/v2 v2.2.4 h1:8HO1Jh2gg7boQKA3hsDAIXd9zwieu5uXwDXEcTOD9js=
codeberg.org/gruf/go-store/v2 v2.2.4/go.mod h1:zI4VWe5CpXAktYMtaBMrgA5QmO0sQH53LBRvfn1huys= codeberg.org/gruf/go-store/v2 v2.2.4/go.mod h1:zI4VWe5CpXAktYMtaBMrgA5QmO0sQH53LBRvfn1huys=
codeberg.org/gruf/go-structr v0.1.1 h1:nR6EcZjXn+oby2nH1Mi6i8S5GWhyjUknkQMXsjbbK0g=
codeberg.org/gruf/go-structr v0.1.1/go.mod h1:OBajB6wcz0BbX0Ns88w2rdUF52rgIej471NJgV0GCW4=
codeberg.org/superseriousbusiness/exif-terminator v0.7.0 h1:Y6VApSXhKqExG0H2hZ2JelRK4xmWdjDQjn13CpEfzko= codeberg.org/superseriousbusiness/exif-terminator v0.7.0 h1:Y6VApSXhKqExG0H2hZ2JelRK4xmWdjDQjn13CpEfzko=
codeberg.org/superseriousbusiness/exif-terminator v0.7.0/go.mod h1:gCWKduudUWFzsnixoMzu0FYVdxHWG+AbXnZ50DqxsUE= codeberg.org/superseriousbusiness/exif-terminator v0.7.0/go.mod h1:gCWKduudUWFzsnixoMzu0FYVdxHWG+AbXnZ50DqxsUE=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=

30
internal/cache/ap.go vendored
View File

@ -1,30 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package cache
type APCaches struct{}
// Init will initialize all the ActivityPub caches in this collection.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
func (c *APCaches) Init() {}
// Start will attempt to start all of the ActivityPub caches, or panic.
func (c *APCaches) Start() {}
// Stop will attempt to stop all of the ActivityPub caches, or panic.
func (c *APCaches) Stop() {}

View File

@ -18,8 +18,9 @@
package cache package cache
import ( import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache/headerfilter" "github.com/superseriousbusiness/gotosocial/internal/cache/headerfilter"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
) )
@ -49,198 +50,59 @@ type Caches struct {
func (c *Caches) Init() { func (c *Caches) Init() {
log.Infof(nil, "init: %p", c) log.Infof(nil, "init: %p", c)
c.GTS.Init() c.initAccount()
c.Visibility.Init() c.initAccountNote()
c.initApplication()
// Setup cache invalidate hooks. c.initBlock()
// !! READ THE METHOD COMMENT c.initBlockIDs()
c.setuphooks() c.initBoostOfIDs()
c.initDomainAllow()
c.initDomainBlock()
c.initEmoji()
c.initEmojiCategory()
c.initFollow()
c.initFollowIDs()
c.initFollowRequest()
c.initFollowRequestIDs()
c.initInReplyToIDs()
c.initInstance()
c.initList()
c.initListEntry()
c.initMarker()
c.initMedia()
c.initMention()
c.initNotification()
c.initPoll()
c.initPollVote()
c.initPollVoteIDs()
c.initReport()
c.initStatus()
c.initStatusFave()
c.initTag()
c.initThreadMute()
c.initStatusFaveIDs()
c.initTombstone()
c.initUser()
c.initWebfinger()
c.initVisibility()
} }
// Start will start both the GTS and AP cache collections. // Start will start any caches that require a background
// routine, which usually means any kind of TTL caches.
func (c *Caches) Start() { func (c *Caches) Start() {
log.Infof(nil, "start: %p", c) log.Infof(nil, "start: %p", c)
c.GTS.Start() tryUntil("starting *gtsmodel.Webfinger cache", 5, func() bool {
c.Visibility.Start() return c.GTS.Webfinger.Start(5 * time.Minute)
})
} }
// Stop will stop both the GTS and AP cache collections. // Stop will stop any caches that require a background
// routine, which usually means any kind of TTL caches.
func (c *Caches) Stop() { func (c *Caches) Stop() {
log.Infof(nil, "stop: %p", c) log.Infof(nil, "stop: %p", c)
c.GTS.Stop() tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.GTS.Webfinger.Stop)
c.Visibility.Stop()
}
// setuphooks sets necessary cache invalidation hooks between caches,
// as an invalidation indicates a database INSERT / UPDATE / DELETE.
// NOTE THEY ARE ONLY CALLED WHEN THE ITEM IS IN THE CACHE, SO FOR
// HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE.
func (c *Caches) setuphooks() {
c.GTS.Account().SetInvalidateCallback(func(account *gtsmodel.Account) {
// Invalidate account ID cached visibility.
c.Visibility.Invalidate("ItemID", account.ID)
c.Visibility.Invalidate("RequesterID", account.ID)
// Invalidate this account's
// following / follower lists.
// (see FollowIDs() comment for details).
c.GTS.FollowIDs().InvalidateAll(
">"+account.ID,
"l>"+account.ID,
"<"+account.ID,
"l<"+account.ID,
)
// Invalidate this account's
// follow requesting / request lists.
// (see FollowRequestIDs() comment for details).
c.GTS.FollowRequestIDs().InvalidateAll(
">"+account.ID,
"<"+account.ID,
)
// Invalidate this account's block lists.
c.GTS.BlockIDs().Invalidate(account.ID)
})
c.GTS.Block().SetInvalidateCallback(func(block *gtsmodel.Block) {
// Invalidate block origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.AccountID)
c.Visibility.Invalidate("RequesterID", block.AccountID)
// Invalidate block target account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.TargetAccountID)
c.Visibility.Invalidate("RequesterID", block.TargetAccountID)
// Invalidate source account's block lists.
c.GTS.BlockIDs().Invalidate(block.AccountID)
})
c.GTS.EmojiCategory().SetInvalidateCallback(func(category *gtsmodel.EmojiCategory) {
// Invalidate any emoji in this category.
c.GTS.Emoji().Invalidate("CategoryID", category.ID)
})
c.GTS.Follow().SetInvalidateCallback(func(follow *gtsmodel.Follow) {
// Invalidate follow request with this same ID.
c.GTS.FollowRequest().Invalidate("ID", follow.ID)
// Invalidate any related list entries.
c.GTS.ListEntry().Invalidate("FollowID", follow.ID)
// Invalidate follow origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.AccountID)
c.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow target account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.TargetAccountID)
c.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
// Invalidate source account's following
// lists, and destination's follwer lists.
// (see FollowIDs() comment for details).
c.GTS.FollowIDs().InvalidateAll(
">"+follow.AccountID,
"l>"+follow.AccountID,
"<"+follow.AccountID,
"l<"+follow.AccountID,
"<"+follow.TargetAccountID,
"l<"+follow.TargetAccountID,
">"+follow.TargetAccountID,
"l>"+follow.TargetAccountID,
)
})
c.GTS.FollowRequest().SetInvalidateCallback(func(followReq *gtsmodel.FollowRequest) {
// Invalidate follow with this same ID.
c.GTS.Follow().Invalidate("ID", followReq.ID)
// Invalidate source account's followreq
// lists, and destinations follow req lists.
// (see FollowRequestIDs() comment for details).
c.GTS.FollowRequestIDs().InvalidateAll(
">"+followReq.AccountID,
"<"+followReq.AccountID,
">"+followReq.TargetAccountID,
"<"+followReq.TargetAccountID,
)
})
c.GTS.List().SetInvalidateCallback(func(list *gtsmodel.List) {
// Invalidate all cached entries of this list.
c.GTS.ListEntry().Invalidate("ListID", list.ID)
})
c.GTS.Media().SetInvalidateCallback(func(media *gtsmodel.MediaAttachment) {
if *media.Avatar || *media.Header {
// Invalidate cache of attaching account.
c.GTS.Account().Invalidate("ID", media.AccountID)
}
if media.StatusID != "" {
// Invalidate cache of attaching status.
c.GTS.Status().Invalidate("ID", media.StatusID)
}
})
c.GTS.Poll().SetInvalidateCallback(func(poll *gtsmodel.Poll) {
// Invalidate all cached votes of this poll.
c.GTS.PollVote().Invalidate("PollID", poll.ID)
// Invalidate cache of poll vote IDs.
c.GTS.PollVoteIDs().Invalidate(poll.ID)
})
c.GTS.PollVote().SetInvalidateCallback(func(vote *gtsmodel.PollVote) {
// Invalidate cached poll (contains no. votes).
c.GTS.Poll().Invalidate("ID", vote.PollID)
// Invalidate cache of poll vote IDs.
c.GTS.PollVoteIDs().Invalidate(vote.PollID)
})
c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) {
// Invalidate status ID cached visibility.
c.Visibility.Invalidate("ItemID", status.ID)
for _, id := range status.AttachmentIDs {
// Invalidate each media by the IDs we're aware of.
// This must be done as the status table is aware of
// the media IDs in use before the media table is
// aware of the status ID they are linked to.
//
// c.GTS.Media().Invalidate("StatusID") will not work.
c.GTS.Media().Invalidate("ID", id)
}
if status.BoostOfID != "" {
// Invalidate boost ID list of the original status.
c.GTS.BoostOfIDs().Invalidate(status.BoostOfID)
}
if status.InReplyToID != "" {
// Invalidate in reply to ID list of original status.
c.GTS.InReplyToIDs().Invalidate(status.InReplyToID)
}
if status.PollID != "" {
// Invalidate cache of attached poll ID.
c.GTS.Poll().Invalidate("ID", status.PollID)
}
})
c.GTS.StatusFave().SetInvalidateCallback(func(fave *gtsmodel.StatusFave) {
// Invalidate status fave ID list for this status.
c.GTS.StatusFaveIDs().Invalidate(fave.StatusID)
})
c.GTS.User().SetInvalidateCallback(func(user *gtsmodel.User) {
// Invalidate local account ID cached visibility.
c.Visibility.Invalidate("ItemID", user.AccountID)
c.Visibility.Invalidate("RequesterID", user.AccountID)
})
} }
// Sweep will sweep all the available caches to ensure none // Sweep will sweep all the available caches to ensure none
@ -250,30 +112,30 @@ func (c *Caches) setuphooks() {
// require an eviction on every single write, which adds // require an eviction on every single write, which adds
// significant overhead to all cache writes. // significant overhead to all cache writes.
func (c *Caches) Sweep(threshold float64) { func (c *Caches) Sweep(threshold float64) {
c.GTS.Account().Trim(threshold) c.GTS.Account.Trim(threshold)
c.GTS.AccountNote().Trim(threshold) c.GTS.AccountNote.Trim(threshold)
c.GTS.Block().Trim(threshold) c.GTS.Block.Trim(threshold)
c.GTS.BlockIDs().Trim(threshold) c.GTS.BlockIDs.Trim(threshold)
c.GTS.Emoji().Trim(threshold) c.GTS.Emoji.Trim(threshold)
c.GTS.EmojiCategory().Trim(threshold) c.GTS.EmojiCategory.Trim(threshold)
c.GTS.Follow().Trim(threshold) c.GTS.Follow.Trim(threshold)
c.GTS.FollowIDs().Trim(threshold) c.GTS.FollowIDs.Trim(threshold)
c.GTS.FollowRequest().Trim(threshold) c.GTS.FollowRequest.Trim(threshold)
c.GTS.FollowRequestIDs().Trim(threshold) c.GTS.FollowRequestIDs.Trim(threshold)
c.GTS.Instance().Trim(threshold) c.GTS.Instance.Trim(threshold)
c.GTS.List().Trim(threshold) c.GTS.List.Trim(threshold)
c.GTS.ListEntry().Trim(threshold) c.GTS.ListEntry.Trim(threshold)
c.GTS.Marker().Trim(threshold) c.GTS.Marker.Trim(threshold)
c.GTS.Media().Trim(threshold) c.GTS.Media.Trim(threshold)
c.GTS.Mention().Trim(threshold) c.GTS.Mention.Trim(threshold)
c.GTS.Notification().Trim(threshold) c.GTS.Notification.Trim(threshold)
c.GTS.Poll().Trim(threshold) c.GTS.Poll.Trim(threshold)
c.GTS.Report().Trim(threshold) c.GTS.Report.Trim(threshold)
c.GTS.Status().Trim(threshold) c.GTS.Status.Trim(threshold)
c.GTS.StatusFave().Trim(threshold) c.GTS.StatusFave.Trim(threshold)
c.GTS.Tag().Trim(threshold) c.GTS.Tag.Trim(threshold)
c.GTS.ThreadMute().Trim(threshold) c.GTS.ThreadMute.Trim(threshold)
c.GTS.Tombstone().Trim(threshold) c.GTS.Tombstone.Trim(threshold)
c.GTS.User().Trim(threshold) c.GTS.User.Trim(threshold)
c.Visibility.Trim(threshold) c.Visibility.Trim(threshold)
} }

1071
internal/cache/db.go vendored Normal file

File diff suppressed because it is too large Load Diff

1119
internal/cache/gts.go vendored

File diff suppressed because it is too large Load Diff

192
internal/cache/invalidate.go vendored Normal file
View File

@ -0,0 +1,192 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package cache
import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Below are cache invalidation hooks between other caches,
// as an invalidation indicates a database INSERT / UPDATE / DELETE.
// NOTE THEY ARE ONLY CALLED WHEN THE ITEM IS IN THE CACHE, SO FOR
// HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE.
func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) {
// Invalidate account ID cached visibility.
c.Visibility.Invalidate("ItemID", account.ID)
c.Visibility.Invalidate("RequesterID", account.ID)
// Invalidate this account's
// following / follower lists.
// (see FollowIDs() comment for details).
c.GTS.FollowIDs.InvalidateAll(
">"+account.ID,
"l>"+account.ID,
"<"+account.ID,
"l<"+account.ID,
)
// Invalidate this account's
// follow requesting / request lists.
// (see FollowRequestIDs() comment for details).
c.GTS.FollowRequestIDs.InvalidateAll(
">"+account.ID,
"<"+account.ID,
)
// Invalidate this account's block lists.
c.GTS.BlockIDs.Invalidate(account.ID)
}
func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) {
// Invalidate block origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.AccountID)
c.Visibility.Invalidate("RequesterID", block.AccountID)
// Invalidate block target account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.TargetAccountID)
c.Visibility.Invalidate("RequesterID", block.TargetAccountID)
// Invalidate source account's block lists.
c.GTS.BlockIDs.Invalidate(block.AccountID)
}
func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) {
// Invalidate any emoji in this category.
c.GTS.Emoji.Invalidate("CategoryID", category.ID)
}
func (c *Caches) OnInvalidateFollow(follow *gtsmodel.Follow) {
// Invalidate follow request with this same ID.
c.GTS.FollowRequest.Invalidate("ID", follow.ID)
// Invalidate any related list entries.
c.GTS.ListEntry.Invalidate("FollowID", follow.ID)
// Invalidate follow origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.AccountID)
c.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow target account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.TargetAccountID)
c.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
// Invalidate source account's following
// lists, and destination's follwer lists.
// (see FollowIDs() comment for details).
c.GTS.FollowIDs.InvalidateAll(
">"+follow.AccountID,
"l>"+follow.AccountID,
"<"+follow.AccountID,
"l<"+follow.AccountID,
"<"+follow.TargetAccountID,
"l<"+follow.TargetAccountID,
">"+follow.TargetAccountID,
"l>"+follow.TargetAccountID,
)
}
func (c *Caches) OnInvalidateFollowRequest(followReq *gtsmodel.FollowRequest) {
// Invalidate follow with this same ID.
c.GTS.Follow.Invalidate("ID", followReq.ID)
// Invalidate source account's followreq
// lists, and destinations follow req lists.
// (see FollowRequestIDs() comment for details).
c.GTS.FollowRequestIDs.InvalidateAll(
">"+followReq.AccountID,
"<"+followReq.AccountID,
">"+followReq.TargetAccountID,
"<"+followReq.TargetAccountID,
)
}
func (c *Caches) OnInvalidateList(list *gtsmodel.List) {
// Invalidate all cached entries of this list.
c.GTS.ListEntry.Invalidate("ListID", list.ID)
}
func (c *Caches) OnInvalidateMedia(media *gtsmodel.MediaAttachment) {
if (media.Avatar != nil && *media.Avatar) ||
(media.Header != nil && *media.Header) {
// Invalidate cache of attaching account.
c.GTS.Account.Invalidate("ID", media.AccountID)
}
if media.StatusID != "" {
// Invalidate cache of attaching status.
c.GTS.Status.Invalidate("ID", media.StatusID)
}
}
func (c *Caches) OnInvalidatePoll(poll *gtsmodel.Poll) {
// Invalidate all cached votes of this poll.
c.GTS.PollVote.Invalidate("PollID", poll.ID)
// Invalidate cache of poll vote IDs.
c.GTS.PollVoteIDs.Invalidate(poll.ID)
}
func (c *Caches) OnInvalidatePollVote(vote *gtsmodel.PollVote) {
// Invalidate cached poll (contains no. votes).
c.GTS.Poll.Invalidate("ID", vote.PollID)
// Invalidate cache of poll vote IDs.
c.GTS.PollVoteIDs.Invalidate(vote.PollID)
}
func (c *Caches) OnInvalidateStatus(status *gtsmodel.Status) {
// Invalidate status ID cached visibility.
c.Visibility.Invalidate("ItemID", status.ID)
for _, id := range status.AttachmentIDs {
// Invalidate each media by the IDs we're aware of.
// This must be done as the status table is aware of
// the media IDs in use before the media table is
// aware of the status ID they are linked to.
//
// c.GTS.Media().Invalidate("StatusID") will not work.
c.GTS.Media.Invalidate("ID", id)
}
if status.BoostOfID != "" {
// Invalidate boost ID list of the original status.
c.GTS.BoostOfIDs.Invalidate(status.BoostOfID)
}
if status.InReplyToID != "" {
// Invalidate in reply to ID list of original status.
c.GTS.InReplyToIDs.Invalidate(status.InReplyToID)
}
if status.PollID != "" {
// Invalidate cache of attached poll ID.
c.GTS.Poll.Invalidate("ID", status.PollID)
}
}
func (c *Caches) OnInvalidateStatusFave(fave *gtsmodel.StatusFave) {
// Invalidate status fave ID list for this status.
c.GTS.StatusFaveIDs.Invalidate(fave.StatusID)
}
func (c *Caches) OnInvalidateUser(user *gtsmodel.User) {
// Invalidate local account ID cached visibility.
c.Visibility.Invalidate("ItemID", user.AccountID)
c.Visibility.Invalidate("RequesterID", user.AccountID)
}

View File

@ -18,18 +18,16 @@
package cache package cache
import ( import (
"codeberg.org/gruf/go-cache/v3/result" "codeberg.org/gruf/go-structr"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
) )
type VisibilityCache struct { type VisibilityCache struct {
*result.Cache[*CachedVisibility] structr.Cache[*CachedVisibility]
} }
// Init will initialize the visibility cache in this collection. func (c *Caches) initVisibility() {
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
func (c *VisibilityCache) Init() {
// Calculate maximum cache size. // Calculate maximum cache size.
cap := calculateResultCacheMax( cap := calculateResultCacheMax(
sizeofVisibility(), // model in-mem size. sizeofVisibility(), // model in-mem size.
@ -38,25 +36,22 @@ func (c *VisibilityCache) Init() {
log.Infof(nil, "Visibility cache size = %d", cap) log.Infof(nil, "Visibility cache size = %d", cap)
c.Cache = result.New([]result.Lookup{ copyF := func(v1 *CachedVisibility) *CachedVisibility {
{Name: "ItemID", Multi: true},
{Name: "RequesterID", Multi: true},
{Name: "Type.RequesterID.ItemID"},
}, func(v1 *CachedVisibility) *CachedVisibility {
v2 := new(CachedVisibility) v2 := new(CachedVisibility)
*v2 = *v1 *v2 = *v1
return v2 return v2
}, cap) }
c.Cache.IgnoreErrors(ignoreErrors) c.Visibility.Init(structr.Config[*CachedVisibility]{
} Indices: []structr.IndexConfig{
{Fields: "ItemID", Multiple: true},
// Start will attempt to start the visibility cache, or panic. {Fields: "RequesterID", Multiple: true},
func (c *VisibilityCache) Start() { {Fields: "Type,RequesterID,ItemID"},
} },
MaxSize: cap,
// Stop will attempt to stop the visibility cache, or panic. IgnoreErr: ignoreErrors,
func (c *VisibilityCache) Stop() { CopyValue: copyF,
})
} }
// VisibilityType represents a visibility lookup type. // VisibilityType represents a visibility lookup type.

View File

@ -116,7 +116,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
return a.getAccount( return a.getAccount(
ctx, ctx,
"Username.Domain", "Username,Domain",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
q := a.db.NewSelect(). q := a.db.NewSelect().
Model(account) Model(account)
@ -224,7 +224,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) { func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {
// Fetch account from database cache with loader callback // Fetch account from database cache with loader callback
account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) { account, err := a.state.Caches.GTS.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) {
var account gtsmodel.Account var account gtsmodel.Account
// Not cached! Perform database query // Not cached! Perform database query
@ -325,7 +325,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
} }
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error { func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error {
return a.state.Caches.GTS.Account().Store(account, func() error { return a.state.Caches.GTS.Account.Store(account, func() error {
// It is safe to run this database transaction within cache.Store // It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook. // as the cache does not attempt a mutex lock until AFTER hook.
// //
@ -354,7 +354,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "updated_at") columns = append(columns, "updated_at")
} }
return a.state.Caches.GTS.Account().Store(account, func() error { return a.state.Caches.GTS.Account.Store(account, func() error {
// It is safe to run this database transaction within cache.Store // It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook. // as the cache does not attempt a mutex lock until AFTER hook.
// //
@ -393,7 +393,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
} }
func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
defer a.state.Caches.GTS.Account().Invalidate("ID", id) defer a.state.Caches.GTS.Account.Invalidate("ID", id)
// Load account into cache before attempting a delete, // Load account into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate // as we need it cached in order to trigger the invalidate
@ -635,6 +635,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
return nil, err return nil, err
} }
if len(statusIDs) == 0 {
return nil, db.ErrNoEntries
}
// If we're paging up, we still want statuses // If we're paging up, we still want statuses
// to be sorted by ID desc, so reverse ids slice. // to be sorted by ID desc, so reverse ids slice.
// https://zchee.github.io/golang-wiki/SliceTricks/#reversing // https://zchee.github.io/golang-wiki/SliceTricks/#reversing
@ -644,7 +648,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
} }
} }
return a.statusesFromIDs(ctx, statusIDs) return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
} }
func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) { func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) {
@ -662,7 +666,11 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri
return nil, err return nil, err
} }
return a.statusesFromIDs(ctx, statusIDs) if len(statusIDs) == 0 {
return nil, db.ErrNoEntries
}
return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
} }
func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) { func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) {
@ -710,29 +718,9 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
return nil, err return nil, err
} }
return a.statusesFromIDs(ctx, statusIDs)
}
func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {
// Catch case of no statuses early
if len(statusIDs) == 0 { if len(statusIDs) == 0 {
return nil, db.ErrNoEntries return nil, db.ErrNoEntries
} }
// Allocate return slice (will be at most len statusIDS) return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
for _, id := range statusIDs {
// Fetch from status from database by ID
status, err := a.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status %q: %v", id, err)
continue
}
// Append to return slice
statuses = append(statuses, status)
}
return statuses, nil
} }

View File

@ -53,7 +53,7 @@ func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID s
} }
func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) { func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) {
return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) { return a.state.Caches.GTS.Application.LoadOne(lookup, func() (*gtsmodel.Application, error) {
var app gtsmodel.Application var app gtsmodel.Application
// Not cached! Perform database query. // Not cached! Perform database query.
@ -66,7 +66,7 @@ func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQue
} }
func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error { func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error {
return a.state.Caches.GTS.Application().Store(app, func() error { return a.state.Caches.GTS.Application.Store(app, func() error {
_, err := a.db.NewInsert().Model(app).Exec(ctx) _, err := a.db.NewInsert().Model(app).Exec(ctx)
return err return err
}) })
@ -91,7 +91,7 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI
// //
// Clear application from the cache. // Clear application from the cache.
a.state.Caches.GTS.Application().Invalidate("ClientID", clientID) a.state.Caches.GTS.Application.Invalidate("ClientID", clientID)
return nil return nil
} }

View File

@ -258,7 +258,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
state: state, state: state,
}, },
Tag: &tagDB{ Tag: &tagDB{
conn: db, db: db,
state: state, state: state,
}, },
Thread: &threadDB{ Thread: &threadDB{

View File

@ -51,7 +51,7 @@ func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.Domain
} }
// Clear the domain allow cache (for later reload) // Clear the domain allow cache (for later reload)
d.state.Caches.GTS.DomainAllow().Clear() d.state.Caches.GTS.DomainAllow.Clear()
return nil return nil
} }
@ -126,7 +126,7 @@ func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
} }
// Clear the domain allow cache (for later reload) // Clear the domain allow cache (for later reload)
d.state.Caches.GTS.DomainAllow().Clear() d.state.Caches.GTS.DomainAllow.Clear()
return nil return nil
} }
@ -147,7 +147,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
} }
// Clear the domain block cache (for later reload) // Clear the domain block cache (for later reload)
d.state.Caches.GTS.DomainBlock().Clear() d.state.Caches.GTS.DomainBlock.Clear()
return nil return nil
} }
@ -222,7 +222,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
} }
// Clear the domain block cache (for later reload) // Clear the domain block cache (for later reload)
d.state.Caches.GTS.DomainBlock().Clear() d.state.Caches.GTS.DomainBlock.Clear()
return nil return nil
} }
@ -241,7 +241,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
} }
// Check the cache for an explicit domain allow (hydrating the cache with callback if necessary). // Check the cache for an explicit domain allow (hydrating the cache with callback if necessary).
explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) { explicitAllow, err := d.state.Caches.GTS.DomainAllow.Matches(domain, func() ([]string, error) {
var domains []string var domains []string
// Scan list of all explicitly allowed domains from DB // Scan list of all explicitly allowed domains from DB
@ -259,7 +259,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
} }
// Check the cache for a domain block (hydrating the cache with callback if necessary) // Check the cache for a domain block (hydrating the cache with callback if necessary)
explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) { explicitBlock, err := d.state.Caches.GTS.DomainBlock.Matches(domain, func() ([]string, error) {
var domains []string var domains []string
// Scan list of all blocked domains from DB // Scan list of all blocked domains from DB

View File

@ -21,6 +21,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"slices"
"strings" "strings"
"time" "time"
@ -30,6 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect"
) )
@ -40,7 +42,7 @@ type emojiDB struct {
} }
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error { func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error {
return e.state.Caches.GTS.Emoji().Store(emoji, func() error { return e.state.Caches.GTS.Emoji.Store(emoji, func() error {
_, err := e.db.NewInsert().Model(emoji).Exec(ctx) _, err := e.db.NewInsert().Model(emoji).Exec(ctx)
return err return err
}) })
@ -54,7 +56,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column
} }
// Update the emoji model in the database. // Update the emoji model in the database.
return e.state.Caches.GTS.Emoji().Store(emoji, func() error { return e.state.Caches.GTS.Emoji.Store(emoji, func() error {
_, err := e.db. _, err := e.db.
NewUpdate(). NewUpdate().
Model(emoji). Model(emoji).
@ -74,21 +76,21 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
defer func() { defer func() {
// Invalidate cached emoji. // Invalidate cached emoji.
e.state.Caches.GTS. e.state.Caches.GTS.
Emoji(). Emoji.
Invalidate("ID", id) Invalidate("ID", id)
for _, id := range accountIDs { for _, accountID := range accountIDs {
// Invalidate cached account. // Invalidate cached account.
e.state.Caches.GTS. e.state.Caches.GTS.
Account(). Account.
Invalidate("ID", id) Invalidate("ID", accountID)
} }
for _, id := range statusIDs { for _, statusID := range statusIDs {
// Invalidate cached account. // Invalidate cached account.
e.state.Caches.GTS. e.state.Caches.GTS.
Status(). Status.
Invalidate("ID", id) Invalidate("ID", statusID)
} }
}() }()
@ -129,26 +131,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
return err return err
} }
for _, id := range statusIDs { for _, statusID := range statusIDs {
var emojiIDs []string var emojiIDs []string
// Select statuses with ID. // Select statuses with ID.
if _, err := tx.NewSelect(). if _, err := tx.NewSelect().
Table("statuses"). Table("statuses").
Column("emojis"). Column("emojis").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), statusID).
Exec(ctx); err != nil && Exec(ctx); err != nil &&
err != sql.ErrNoRows { err != sql.ErrNoRows {
return err return err
} }
// Drop ID from account emojis. // Delete all instances of this emoji ID from status emojis.
emojiIDs = dropID(emojiIDs, id) emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool {
return emojiID == id
})
// Update status emoji IDs. // Update status emoji IDs.
if _, err := tx.NewUpdate(). if _, err := tx.NewUpdate().
Table("statuses"). Table("statuses").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), statusID).
Set("emojis = ?", emojiIDs). Set("emojis = ?", emojiIDs).
Exec(ctx); err != nil && Exec(ctx); err != nil &&
err != sql.ErrNoRows { err != sql.ErrNoRows {
@ -156,26 +160,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
} }
} }
for _, id := range accountIDs { for _, accountID := range accountIDs {
var emojiIDs []string var emojiIDs []string
// Select account with ID. // Select account with ID.
if _, err := tx.NewSelect(). if _, err := tx.NewSelect().
Table("accounts"). Table("accounts").
Column("emojis"). Column("emojis").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), accountID).
Exec(ctx); err != nil && Exec(ctx); err != nil &&
err != sql.ErrNoRows { err != sql.ErrNoRows {
return err return err
} }
// Drop ID from account emojis. // Delete all instances of this emoji ID from account emojis.
emojiIDs = dropID(emojiIDs, id) emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool {
return emojiID == id
})
// Update account emoji IDs. // Update account emoji IDs.
if _, err := tx.NewUpdate(). if _, err := tx.NewUpdate().
Table("accounts"). Table("accounts").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), accountID).
Set("emojis = ?", emojiIDs). Set("emojis = ?", emojiIDs).
Exec(ctx); err != nil && Exec(ctx); err != nil &&
err != sql.ErrNoRows { err != sql.ErrNoRows {
@ -431,7 +437,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) { func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
"Shortcode.Domain", "Shortcode,Domain",
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
q := e.db. q := e.db.
NewSelect(). NewSelect().
@ -468,7 +474,7 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string
} }
func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error { func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error {
return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error { return e.state.Caches.GTS.EmojiCategory.Store(emojiCategory, func() error {
_, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx) _, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx)
return err return err
}) })
@ -520,7 +526,7 @@ func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gts
func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) { func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) {
// Fetch emoji from database cache with loader callback // Fetch emoji from database cache with loader callback
emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) { emoji, err := e.state.Caches.GTS.Emoji.LoadOne(lookup, func() (*gtsmodel.Emoji, error) {
var emoji gtsmodel.Emoji var emoji gtsmodel.Emoji
// Not cached! Perform database query // Not cached! Perform database query
@ -568,28 +574,72 @@ func (e *emojiDB) PopulateEmoji(ctx context.Context, emoji *gtsmodel.Emoji) erro
return errs.Combine() return errs.Combine()
} }
func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) { func (e *emojiDB) GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, error) {
if len(emojiIDs) == 0 { if len(ids) == 0 {
return nil, db.ErrNoEntries return nil, db.ErrNoEntries
} }
emojis := make([]*gtsmodel.Emoji, 0, len(emojiIDs)) // Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range emojiIDs { // Load all emoji IDs via cache loader callbacks.
emoji, err := e.GetEmojiByID(ctx, id) emojis, err := e.state.Caches.GTS.Emoji.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached emoji loader function.
func() ([]*gtsmodel.Emoji, error) {
// Preallocate expected length of uncached emojis.
emojis := make([]*gtsmodel.Emoji, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := e.db.NewSelect().
Model(&emojis).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return emojis, nil
},
)
if err != nil { if err != nil {
log.Errorf(ctx, "emojisFromIDs: error getting emoji %q: %v", id, err) return nil, err
continue
} }
emojis = append(emojis, emoji) // Reorder the emojis by their
// IDs to ensure in correct order.
getID := func(e *gtsmodel.Emoji) string { return e.ID }
util.OrderBy(emojis, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return emojis, nil
} }
// Populate all loaded emojis, removing those we fail to
// populate (removes needing so many nil checks everywhere).
emojis = slices.DeleteFunc(emojis, func(emoji *gtsmodel.Emoji) bool {
if err := e.PopulateEmoji(ctx, emoji); err != nil {
log.Errorf(ctx, "error populating emoji %s: %v", emoji.ID, err)
return true
}
return false
})
return emojis, nil return emojis, nil
} }
func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) { func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) {
return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) { return e.state.Caches.GTS.EmojiCategory.LoadOne(lookup, func() (*gtsmodel.EmojiCategory, error) {
var category gtsmodel.EmojiCategory var category gtsmodel.EmojiCategory
// Not cached! Perform database query // Not cached! Perform database query
@ -601,36 +651,51 @@ func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery f
}, keyParts...) }, keyParts...)
} }
func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, error) { func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, error) {
if len(emojiCategoryIDs) == 0 { if len(ids) == 0 {
return nil, db.ErrNoEntries return nil, db.ErrNoEntries
} }
emojiCategories := make([]*gtsmodel.EmojiCategory, 0, len(emojiCategoryIDs)) // Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range emojiCategoryIDs { // Load all category IDs via cache loader callbacks.
emojiCategory, err := e.GetEmojiCategory(ctx, id) categories, err := e.state.Caches.GTS.EmojiCategory.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached emoji loader function.
func() ([]*gtsmodel.EmojiCategory, error) {
// Preallocate expected length of uncached categories.
categories := make([]*gtsmodel.EmojiCategory, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := e.db.NewSelect().
Model(&categories).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return categories, nil
},
)
if err != nil { if err != nil {
log.Errorf(ctx, "error getting emoji category %q: %v", id, err) return nil, err
continue
} }
emojiCategories = append(emojiCategories, emojiCategory) // Reorder the categories by their
} // IDs to ensure in correct order.
getID := func(c *gtsmodel.EmojiCategory) string { return c.ID }
util.OrderBy(categories, ids, getID)
return emojiCategories, nil return categories, nil
}
// dropIDs drops given ID string from IDs slice.
func dropID(ids []string, id string) []string {
for i := 0; i < len(ids); {
if ids[i] == id {
// Remove this reference.
copy(ids[i:], ids[i+1:])
ids = ids[:len(ids)-1]
continue
}
i++
}
return ids
} }

View File

@ -143,7 +143,7 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.
func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) { func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) {
// Fetch instance from database cache with loader callback // Fetch instance from database cache with loader callback
instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) { instance, err := i.state.Caches.GTS.Instance.LoadOne(lookup, func() (*gtsmodel.Instance, error) {
var instance gtsmodel.Instance var instance gtsmodel.Instance
// Not cached! Perform database query. // Not cached! Perform database query.
@ -219,7 +219,7 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
} }
return i.state.Caches.GTS.Instance().Store(instance, func() error { return i.state.Caches.GTS.Instance.Store(instance, func() error {
_, err := i.db.NewInsert().Model(instance).Exec(ctx) _, err := i.db.NewInsert().Model(instance).Exec(ctx)
return err return err
}) })
@ -239,7 +239,7 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst
columns = append(columns, "updated_at") columns = append(columns, "updated_at")
} }
return i.state.Caches.GTS.Instance().Store(instance, func() error { return i.state.Caches.GTS.Instance.Store(instance, func() error {
_, err := i.db. _, err := i.db.
NewUpdate(). NewUpdate().
Model(instance). Model(instance).

View File

@ -21,6 +21,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"slices"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -29,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -56,7 +58,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er
} }
func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) { func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {
list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) { list, err := l.state.Caches.GTS.List.LoadOne(lookup, func() (*gtsmodel.List, error) {
var list gtsmodel.List var list gtsmodel.List
// Not cached! Perform database query. // Not cached! Perform database query.
@ -100,18 +102,8 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]
return nil, nil return nil, nil
} }
// Select each list using its ID to ensure cache used. // Return lists by their IDs.
lists := make([]*gtsmodel.List, 0, len(listIDs)) return l.GetListsByIDs(ctx, listIDs)
for _, id := range listIDs {
list, err := l.state.DB.GetListByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list %q: %v", id, err)
continue
}
lists = append(lists, list)
}
return lists, nil
} }
func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
@ -147,7 +139,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
} }
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
return l.state.Caches.GTS.List().Store(list, func() error { return l.state.Caches.GTS.List.Store(list, func() error {
_, err := l.db.NewInsert().Model(list).Exec(ctx) _, err := l.db.NewInsert().Model(list).Exec(ctx)
return err return err
}) })
@ -162,7 +154,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
defer func() { defer func() {
// Invalidate all entries for this list ID. // Invalidate all entries for this list ID.
l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID) l.state.Caches.GTS.ListEntry.Invalidate("ListID", list.ID)
// Invalidate this entire list's timeline. // Invalidate this entire list's timeline.
if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil {
@ -170,7 +162,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
} }
}() }()
return l.state.Caches.GTS.List().Store(list, func() error { return l.state.Caches.GTS.List.Store(list, func() error {
_, err := l.db.NewUpdate(). _, err := l.db.NewUpdate().
Model(list). Model(list).
Where("? = ?", bun.Ident("list.id"), list.ID). Where("? = ?", bun.Ident("list.id"), list.ID).
@ -198,7 +190,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
defer func() { defer func() {
// Invalidate this list from cache. // Invalidate this list from cache.
l.state.Caches.GTS.List().Invalidate("ID", id) l.state.Caches.GTS.List.Invalidate("ID", id)
// Invalidate this entire list's timeline. // Invalidate this entire list's timeline.
if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil {
@ -243,7 +235,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis
} }
func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {
listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) { listEntry, err := l.state.Caches.GTS.ListEntry.LoadOne(lookup, func() (*gtsmodel.ListEntry, error) {
var listEntry gtsmodel.ListEntry var listEntry gtsmodel.ListEntry
// Not cached! Perform database query. // Not cached! Perform database query.
@ -344,18 +336,128 @@ func (l *listDB) GetListEntries(ctx context.Context,
} }
} }
// Select each list entry using its ID to ensure cache used. // Return list entries by their IDs.
listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) return l.GetListEntriesByIDs(ctx, entryIDs)
for _, id := range entryIDs { }
listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
if err != nil { func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) {
log.Errorf(ctx, "error fetching list entry %q: %v", id, err) // Preallocate at-worst possible length.
continue uncached := make([]string, 0, len(ids))
// Load all list IDs via cache loader callbacks.
lists, err := l.state.Caches.GTS.List.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
} }
listEntries = append(listEntries, listEntry) }
},
// Uncached list loader function.
func() ([]*gtsmodel.List, error) {
// Preallocate expected length of uncached lists.
lists := make([]*gtsmodel.List, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := l.db.NewSelect().
Model(&lists).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
} }
return listEntries, nil return lists, nil
},
)
if err != nil {
return nil, err
}
// Reorder the lists by their
// IDs to ensure in correct order.
getID := func(l *gtsmodel.List) string { return l.ID }
util.OrderBy(lists, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return lists, nil
}
// Populate all loaded lists, removing those we fail to
// populate (removes needing so many nil checks everywhere).
lists = slices.DeleteFunc(lists, func(list *gtsmodel.List) bool {
if err := l.PopulateList(ctx, list); err != nil {
log.Errorf(ctx, "error populating list %s: %v", list.ID, err)
return true
}
return false
})
return lists, nil
}
func (l *listDB) GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) {
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all entry IDs via cache loader callbacks.
entries, err := l.state.Caches.GTS.ListEntry.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached entry loader function.
func() ([]*gtsmodel.ListEntry, error) {
// Preallocate expected length of uncached entries.
entries := make([]*gtsmodel.ListEntry, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := l.db.NewSelect().
Model(&entries).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return entries, nil
},
)
if err != nil {
return nil, err
}
// Reorder the entries by their
// IDs to ensure in correct order.
getID := func(e *gtsmodel.ListEntry) string { return e.ID }
util.OrderBy(entries, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return entries, nil
}
// Populate all loaded entries, removing those we fail to
// populate (removes needing so many nil checks everywhere).
entries = slices.DeleteFunc(entries, func(entry *gtsmodel.ListEntry) bool {
if err := l.PopulateListEntry(ctx, entry); err != nil {
log.Errorf(ctx, "error populating entry %s: %v", entry.ID, err)
return true
}
return false
})
return entries, nil
} }
func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
@ -376,18 +478,8 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string)
return nil, nil return nil, nil
} }
// Select each list entry using its ID to ensure cache used. // Return list entries by their IDs.
listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) return l.GetListEntriesByIDs(ctx, entryIDs)
for _, id := range entryIDs {
listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
continue
}
listEntries = append(listEntries, listEntry)
}
return listEntries, nil
} }
func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error { func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error {
@ -409,10 +501,10 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List
func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error { func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error {
defer func() { defer func() {
// Collect unique list IDs from the entries. // Collect unique list IDs from the provided entries.
listIDs := collate(func(i int) string { listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string {
return entries[i].ListID return e.ListID
}, len(entries)) })
for _, id := range listIDs { for _, id := range listIDs {
// Invalidate the timeline for the list this entry belongs to. // Invalidate the timeline for the list this entry belongs to.
@ -426,7 +518,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt
return l.db.RunInTx(ctx, func(tx Tx) error { return l.db.RunInTx(ctx, func(tx Tx) error {
for _, entry := range entries { for _, entry := range entries {
entry := entry // rescope entry := entry // rescope
if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error { if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error {
_, err := tx. _, err := tx.
NewInsert(). NewInsert().
Model(entry). Model(entry).
@ -459,7 +551,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
defer func() { defer func() {
// Invalidate this list entry upon delete. // Invalidate this list entry upon delete.
l.state.Caches.GTS.ListEntry().Invalidate("ID", id) l.state.Caches.GTS.ListEntry.Invalidate("ID", id)
// Invalidate the timeline for the list this entry belongs to. // Invalidate the timeline for the list this entry belongs to.
if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil { if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil {
@ -514,24 +606,3 @@ func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, account
return exists, err return exists, err
} }
// collate will collect the values of type T from an expected slice of length 'len',
// passing the expected index to each call of 'get' and deduplicating the end result.
func collate[T comparable](get func(int) T, len int) []T {
ts := make([]T, 0, len)
tm := make(map[T]struct{}, len)
for i := 0; i < len; i++ {
// Get next.
t := get(i)
if _, ok := tm[t]; !ok {
// New value, add
// to map + slice.
ts = append(ts, t)
tm[t] = struct{}{}
}
}
return ts
}

View File

@ -39,8 +39,8 @@ type markerDB struct {
*/ */
func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmodel.MarkerName) (*gtsmodel.Marker, error) { func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmodel.MarkerName) (*gtsmodel.Marker, error) {
marker, err := m.state.Caches.GTS.Marker().Load( marker, err := m.state.Caches.GTS.Marker.LoadOne(
"AccountID.Name", "AccountID,Name",
func() (*gtsmodel.Marker, error) { func() (*gtsmodel.Marker, error) {
var marker gtsmodel.Marker var marker gtsmodel.Marker
@ -52,9 +52,7 @@ func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmode
} }
return &marker, nil return &marker, nil
}, }, accountID, name,
accountID,
name,
) )
if err != nil { if err != nil {
return nil, err // already processed return nil, err // already processed
@ -74,7 +72,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er
marker.Version = prevMarker.Version + 1 marker.Version = prevMarker.Version + 1
} }
return m.state.Caches.GTS.Marker().Store(marker, func() error { return m.state.Caches.GTS.Marker.Store(marker, func() error {
if prevMarker == nil { if prevMarker == nil {
if _, err := m.db.NewInsert(). if _, err := m.db.NewInsert().
Model(marker). Model(marker).

View File

@ -20,14 +20,15 @@ package bundb
import ( import (
"context" "context"
"errors" "errors"
"slices"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -51,25 +52,52 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
} }
func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) { func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) {
attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids)) // Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all media IDs via cache loader callbacks.
media, err := m.state.Caches.GTS.Media.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids { for _, id := range ids {
// Attempt fetch from DB if !load(id) {
attachment, err := m.GetAttachmentByID(ctx, id) uncached = append(uncached, id)
}
}
},
// Uncached media loader function.
func() ([]*gtsmodel.MediaAttachment, error) {
// Preallocate expected length of uncached media attachments.
media := make([]*gtsmodel.MediaAttachment, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := m.db.NewSelect().
Model(&media).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return media, nil
},
)
if err != nil { if err != nil {
log.Errorf(ctx, "error getting attachment %q: %v", id, err) return nil, err
continue
} }
// Append attachment // Reorder the media by their
attachments = append(attachments, attachment) // IDs to ensure in correct order.
} getID := func(m *gtsmodel.MediaAttachment) string { return m.ID }
util.OrderBy(media, ids, getID)
return attachments, nil return media, nil
} }
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) { func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) {
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) { return m.state.Caches.GTS.Media.LoadOne(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment var attachment gtsmodel.MediaAttachment
// Not cached! Perform database query // Not cached! Perform database query
@ -82,7 +110,7 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func
} }
func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error { func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error {
return m.state.Caches.GTS.Media().Store(media, func() error { return m.state.Caches.GTS.Media.Store(media, func() error {
_, err := m.db.NewInsert().Model(media).Exec(ctx) _, err := m.db.NewInsert().Model(media).Exec(ctx)
return err return err
}) })
@ -95,7 +123,7 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt
columns = append(columns, "updated_at") columns = append(columns, "updated_at")
} }
return m.state.Caches.GTS.Media().Store(media, func() error { return m.state.Caches.GTS.Media.Store(media, func() error {
_, err := m.db.NewUpdate(). _, err := m.db.NewUpdate().
Model(media). Model(media).
Where("? = ?", bun.Ident("media_attachment.id"), media.ID). Where("? = ?", bun.Ident("media_attachment.id"), media.ID).
@ -119,7 +147,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
} }
// On return, ensure that media with ID is invalidated. // On return, ensure that media with ID is invalidated.
defer m.state.Caches.GTS.Media().Invalidate("ID", id) defer m.state.Caches.GTS.Media.Invalidate("ID", id)
// Delete media attachment in new transaction. // Delete media attachment in new transaction.
err = m.db.RunInTx(ctx, func(tx Tx) error { err = m.db.RunInTx(ctx, func(tx Tx) error {
@ -171,8 +199,12 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
return gtserror.Newf("error selecting status: %w", err) return gtserror.Newf("error selecting status: %w", err)
} }
if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse // Delete all instances of this deleted media ID from status attachments.
len(updatedIDs) != len(status.AttachmentIDs) { updatedIDs := slices.DeleteFunc(status.AttachmentIDs, func(s string) bool {
return s == id
})
if len(updatedIDs) != len(status.AttachmentIDs) {
// Note: this handles not found. // Note: this handles not found.
// //
// Attachments changed, update the status. // Attachments changed, update the status.

View File

@ -20,6 +20,7 @@ package bundb
import ( import (
"context" "context"
"errors" "errors"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@ -27,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -36,7 +38,7 @@ type mentionDB struct {
} }
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) { func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) {
mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) { mention, err := m.state.Caches.GTS.Mention.LoadOne("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention var mention gtsmodel.Mention
q := m.db. q := m.db.
@ -63,21 +65,64 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
} }
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) { func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) {
mentions := make([]*gtsmodel.Mention, 0, len(ids)) // Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all mention IDs via cache loader callbacks.
mentions, err := m.state.Caches.GTS.Mention.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids { for _, id := range ids {
// Attempt fetch from DB if !load(id) {
mention, err := m.GetMention(ctx, id) uncached = append(uncached, id)
if err != nil {
log.Errorf(ctx, "error getting mention %q: %v", id, err)
continue
} }
}
},
// Append mention // Uncached mention loader function.
mentions = append(mentions, mention) func() ([]*gtsmodel.Mention, error) {
// Preallocate expected length of uncached mentions.
mentions := make([]*gtsmodel.Mention, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := m.db.NewSelect().
Model(&mentions).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
} }
return mentions, nil return mentions, nil
},
)
if err != nil {
return nil, err
}
// Reorder the mentions by their
// IDs to ensure in correct order.
getID := func(m *gtsmodel.Mention) string { return m.ID }
util.OrderBy(mentions, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return mentions, nil
}
// Populate all loaded mentions, removing those we fail to
// populate (removes needing so many nil checks everywhere).
mentions = slices.DeleteFunc(mentions, func(mention *gtsmodel.Mention) bool {
if err := m.PopulateMention(ctx, mention); err != nil {
log.Errorf(ctx, "error populating mention %s: %v", mention.ID, err)
return true
}
return false
})
return mentions, nil
} }
func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) { func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) {
@ -120,14 +165,14 @@ func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Menti
} }
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
return m.state.Caches.GTS.Mention().Store(mention, func() error { return m.state.Caches.GTS.Mention.Store(mention, func() error {
_, err := m.db.NewInsert().Model(mention).Exec(ctx) _, err := m.db.NewInsert().Model(mention).Exec(ctx)
return err return err
}) })
} }
func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error { func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
defer m.state.Caches.GTS.Mention().Invalidate("ID", id) defer m.state.Caches.GTS.Mention.Invalidate("ID", id)
// Load mention into cache before attempting a delete, // Load mention into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate // as we need it cached in order to trigger the invalidate

View File

@ -20,6 +20,7 @@ package bundb
import ( import (
"context" "context"
"errors" "errors"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -37,18 +39,17 @@ type notificationDB struct {
} }
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) { func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) {
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) { return n.getNotification(
var notif gtsmodel.Notification ctx,
"ID",
q := n.db.NewSelect(). func(notif *gtsmodel.Notification) error {
Model(&notif). return n.db.NewSelect().
Where("? = ?", bun.Ident("notification.id"), id) Model(notif).
if err := q.Scan(ctx); err != nil { Where("? = ?", bun.Ident("id"), id).
return nil, err Scan(ctx)
} },
id,
return &notif, nil )
}, id)
} }
func (n *notificationDB) GetNotification( func (n *notificationDB) GetNotification(
@ -58,42 +59,113 @@ func (n *notificationDB) GetNotification(
originAccountID string, originAccountID string,
statusID string, statusID string,
) (*gtsmodel.Notification, error) { ) (*gtsmodel.Notification, error) {
notif, err := n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { return n.getNotification(
var notif gtsmodel.Notification ctx,
"NotificationType,TargetAccountID,OriginAccountID,StatusID",
q := n.db.NewSelect(). func(notif *gtsmodel.Notification) error {
Model(&notif). return n.db.NewSelect().
Model(notif).
Where("? = ?", bun.Ident("notification_type"), notificationType). Where("? = ?", bun.Ident("notification_type"), notificationType).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID). Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("origin_account_id"), originAccountID). Where("? = ?", bun.Ident("origin_account_id"), originAccountID).
Where("? = ?", bun.Ident("status_id"), statusID) Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx)
},
notificationType, targetAccountID, originAccountID, statusID,
)
}
if err := q.Scan(ctx); err != nil { func (n *notificationDB) getNotification(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Notification) error, keyParts ...any) (*gtsmodel.Notification, error) {
// Fetch notification from cache with loader callback
notif, err := n.state.Caches.GTS.Notification.LoadOne(lookup, func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
// Not cached! Perform database query
if err := dbQuery(&notif); err != nil {
return nil, err return nil, err
} }
return &notif, nil return &notif, nil
}, notificationType, targetAccountID, originAccountID, statusID) }, keyParts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // Only a barebones model was requested.
return notif, nil return notif, nil
} }
// Further populate the notif fields where applicable. if err := n.state.DB.PopulateNotification(ctx, notif); err != nil {
if err := n.PopulateNotification(ctx, notif); err != nil {
return nil, err return nil, err
} }
return notif, nil return notif, nil
} }
func (n *notificationDB) GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error) {
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all notif IDs via cache loader callbacks.
notifs, err := n.state.Caches.GTS.Notification.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached notification loader function.
func() ([]*gtsmodel.Notification, error) {
// Preallocate expected length of uncached notifications.
notifs := make([]*gtsmodel.Notification, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := n.db.NewSelect().
Model(&notifs).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return notifs, nil
},
)
if err != nil {
return nil, err
}
// Reorder the notifs by their
// IDs to ensure in correct order.
getID := func(n *gtsmodel.Notification) string { return n.ID }
util.OrderBy(notifs, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return notifs, nil
}
// Populate all loaded notifs, removing those we fail to
// populate (removes needing so many nil checks everywhere).
notifs = slices.DeleteFunc(notifs, func(notif *gtsmodel.Notification) bool {
if err := n.PopulateNotification(ctx, notif); err != nil {
log.Errorf(ctx, "error populating notif %s: %v", notif.ID, err)
return true
}
return false
})
return notifs, nil
}
func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmodel.Notification) error { func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmodel.Notification) error {
var ( var (
errs = gtserror.NewMultiError(2) errs gtserror.MultiError
err error err error
) )
@ -211,31 +283,19 @@ func (n *notificationDB) GetAccountNotifications(
} }
} }
notifs := make([]*gtsmodel.Notification, 0, len(notifIDs)) // Fetch notification models by their IDs.
for _, id := range notifIDs { return n.GetNotificationsByIDs(ctx, notifIDs)
// Attempt fetch from DB
notif, err := n.GetNotificationByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching notification %q: %v", id, err)
continue
}
// Append notification
notifs = append(notifs, notif)
}
return notifs, nil
} }
func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
return n.state.Caches.GTS.Notification().Store(notif, func() error { return n.state.Caches.GTS.Notification.Store(notif, func() error {
_, err := n.db.NewInsert().Model(notif).Exec(ctx) _, err := n.db.NewInsert().Model(notif).Exec(ctx)
return err return err
}) })
} }
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error { func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error {
defer n.state.Caches.GTS.Notification().Invalidate("ID", id) defer n.state.Caches.GTS.Notification.Invalidate("ID", id)
// Load notif into cache before attempting a delete, // Load notif into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate // as we need it cached in order to trigger the invalidate
@ -288,7 +348,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
defer func() { defer func() {
// Invalidate all IDs on return. // Invalidate all IDs on return.
for _, id := range notifIDs { for _, id := range notifIDs {
n.state.Caches.GTS.Notification().Invalidate("ID", id) n.state.Caches.GTS.Notification.Invalidate("ID", id)
} }
}() }()
@ -326,7 +386,7 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu
defer func() { defer func() {
// Invalidate all IDs on return. // Invalidate all IDs on return.
for _, id := range notifIDs { for _, id := range notifIDs {
n.state.Caches.GTS.Notification().Invalidate("ID", id) n.state.Caches.GTS.Notification.Invalidate("ID", id)
} }
}() }()

View File

@ -20,6 +20,7 @@ package bundb
import ( import (
"context" "context"
"errors" "errors"
"slices"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -52,7 +54,7 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er
func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) { func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {
// Fetch poll from database cache with loader callback // Fetch poll from database cache with loader callback
poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) { poll, err := p.state.Caches.GTS.Poll.LoadOne(lookup, func() (*gtsmodel.Poll, error) {
var poll gtsmodel.Poll var poll gtsmodel.Poll
// Not cached! Perform database query. // Not cached! Perform database query.
@ -140,7 +142,7 @@ func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error {
// is non nil and set. // is non nil and set.
poll.CheckVotes() poll.CheckVotes()
return p.state.Caches.GTS.Poll().Store(poll, func() error { return p.state.Caches.GTS.Poll.Store(poll, func() error {
_, err := p.db.NewInsert().Model(poll).Exec(ctx) _, err := p.db.NewInsert().Model(poll).Exec(ctx)
return err return err
}) })
@ -151,7 +153,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st
// is non nil and set. // is non nil and set.
poll.CheckVotes() poll.CheckVotes()
return p.state.Caches.GTS.Poll().Store(poll, func() error { return p.state.Caches.GTS.Poll.Store(poll, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error { return p.db.RunInTx(ctx, func(tx Tx) error {
// Update the status' "updated_at" field. // Update the status' "updated_at" field.
if _, err := tx.NewUpdate(). if _, err := tx.NewUpdate().
@ -184,8 +186,8 @@ func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {
} }
// Invalidate poll by ID from cache. // Invalidate poll by ID from cache.
p.state.Caches.GTS.Poll().Invalidate("ID", id) p.state.Caches.GTS.Poll.Invalidate("ID", id)
p.state.Caches.GTS.PollVoteIDs().Invalidate(id) p.state.Caches.GTS.PollVoteIDs.Invalidate(id)
return nil return nil
} }
@ -207,7 +209,7 @@ func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.Poll
func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) { func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) {
return p.getPollVote( return p.getPollVote(
ctx, ctx,
"PollID.AccountID", "PollID,AccountID",
func(vote *gtsmodel.PollVote) error { func(vote *gtsmodel.PollVote) error {
return p.db.NewSelect(). return p.db.NewSelect().
Model(vote). Model(vote).
@ -222,7 +224,7 @@ func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID str
func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) { func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) {
// Fetch vote from database cache with loader callback // Fetch vote from database cache with loader callback
vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) { vote, err := p.state.Caches.GTS.PollVote.LoadOne(lookup, func() (*gtsmodel.PollVote, error) {
var vote gtsmodel.PollVote var vote gtsmodel.PollVote
// Not cached! Perform database query. // Not cached! Perform database query.
@ -250,7 +252,9 @@ func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*g
} }
func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) { func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) {
voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) {
// Load vote IDs known for given poll ID using loader callback.
voteIDs, err := p.state.Caches.GTS.PollVoteIDs.Load(pollID, func() ([]string, error) {
var voteIDs []string var voteIDs []string
// Vote IDs not in cache, perform DB query! // Vote IDs not in cache, perform DB query!
@ -266,21 +270,62 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P
return nil, err return nil, err
} }
// Preallocate slice of expected length. // Preallocate at-worst possible length.
votes := make([]*gtsmodel.PollVote, 0, len(voteIDs)) uncached := make([]string, 0, len(voteIDs))
// Load all votes from IDs via cache loader callbacks.
votes, err := p.state.Caches.GTS.PollVote.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range voteIDs { for _, id := range voteIDs {
// Fetch poll vote model for this ID. if !load(id) {
vote, err := p.GetPollVoteByID(ctx, id) uncached = append(uncached, id)
if err != nil { }
log.Errorf(ctx, "error getting poll vote %s: %v", id, err) }
continue },
// Uncached poll vote loader function.
func() ([]*gtsmodel.PollVote, error) {
// Preallocate expected length of uncached votes.
votes := make([]*gtsmodel.PollVote, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := p.db.NewSelect().
Model(&votes).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
} }
// Append to return slice. return votes, nil
votes = append(votes, vote) },
)
if err != nil {
return nil, err
} }
// Reorder the poll votes by their
// IDs to ensure in correct order.
getID := func(v *gtsmodel.PollVote) string { return v.ID }
util.OrderBy(votes, voteIDs, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return votes, nil
}
// Populate all loaded votes, removing those we fail to
// populate (removes needing so many nil checks everywhere).
votes = slices.DeleteFunc(votes, func(vote *gtsmodel.PollVote) bool {
if err := p.PopulatePollVote(ctx, vote); err != nil {
log.Errorf(ctx, "error populating vote %s: %v", vote.ID, err)
return true
}
return false
})
return votes, nil return votes, nil
} }
@ -316,7 +361,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote)
} }
func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
return p.state.Caches.GTS.PollVote().Store(vote, func() error { return p.state.Caches.GTS.PollVote.Store(vote, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error { return p.db.RunInTx(ctx, func(tx Tx) error {
// Try insert vote into database. // Try insert vote into database.
if _, err := tx.NewInsert(). if _, err := tx.NewInsert().
@ -416,9 +461,9 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
} }
// Invalidate poll vote and poll entry from caches. // Invalidate poll vote and poll entry from caches.
p.state.Caches.GTS.Poll().Invalidate("ID", pollID) p.state.Caches.GTS.Poll.Invalidate("ID", pollID)
p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID) p.state.Caches.GTS.PollVote.Invalidate("PollID", pollID)
p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)
return nil return nil
} }
@ -428,7 +473,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
// Slice should only ever be of length // Slice should only ever be of length
// 0 or 1; it's a slice of slices only // 0 or 1; it's a slice of slices only
// because we can't LIMIT deletes to 1. // because we can't LIMIT deletes to 1.
var choicesSl [][]int var choicesSlice [][]int
// Delete vote in poll by account, // Delete vote in poll by account,
// returning the ID + choices of the vote. // returning the ID + choices of the vote.
@ -437,17 +482,19 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
Where("? = ?", bun.Ident("poll_id"), pollID). Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID). Where("? = ?", bun.Ident("account_id"), accountID).
Returning("?", bun.Ident("choices")). Returning("?", bun.Ident("choices")).
Scan(ctx, &choicesSl); err != nil { Scan(ctx, &choicesSlice); err != nil {
// irrecoverable. // irrecoverable.
return err return err
} }
if len(choicesSl) != 1 { if len(choicesSlice) != 1 {
// No poll votes by this // No poll votes by this
// acct on this poll. // acct on this poll.
return nil return nil
} }
choices := choicesSl[0]
// Extract the *actual* choices.
choices := choicesSlice[0]
// Select current poll counts from DB, // Select current poll counts from DB,
// taking minimal columns needed to // taking minimal columns needed to
@ -489,9 +536,9 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
} }
// Invalidate poll vote and poll entry from caches. // Invalidate poll vote and poll entry from caches.
p.state.Caches.GTS.Poll().Invalidate("ID", pollID) p.state.Caches.GTS.Poll.Invalidate("ID", pollID)
p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID) p.state.Caches.GTS.PollVote.Invalidate("PollID,AccountID", pollID, accountID)
p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)
return nil return nil
} }

View File

@ -194,7 +194,7 @@ func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID strin
} }
func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) { return loadPagedIDs(r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) {
var followIDs []string var followIDs []string
// Follow IDs not in cache, perform DB query! // Follow IDs not in cache, perform DB query!
@ -209,7 +209,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri
} }
func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) {
return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) { return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) {
var followIDs []string var followIDs []string
// Follow IDs not in cache, perform DB query! // Follow IDs not in cache, perform DB query!
@ -224,7 +224,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
} }
func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) { return loadPagedIDs(r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) {
var followIDs []string var followIDs []string
// Follow IDs not in cache, perform DB query! // Follow IDs not in cache, perform DB query!
@ -239,7 +239,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st
} }
func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) { return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) {
var followIDs []string var followIDs []string
// Follow IDs not in cache, perform DB query! // Follow IDs not in cache, perform DB query!
@ -254,7 +254,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
} }
func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) { return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) {
var followReqIDs []string var followReqIDs []string
// Follow request IDs not in cache, perform DB query! // Follow request IDs not in cache, perform DB query!
@ -269,7 +269,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
} }
func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) { return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) {
var followReqIDs []string var followReqIDs []string
// Follow request IDs not in cache, perform DB query! // Follow request IDs not in cache, perform DB query!
@ -284,7 +284,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
} }
func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) { return loadPagedIDs(r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) {
var blockIDs []string var blockIDs []string
// Block IDs not in cache, perform DB query! // Block IDs not in cache, perform DB query!

View File

@ -20,12 +20,14 @@ package bundb
import ( import (
"context" "context"
"errors" "errors"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -86,7 +88,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod
func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) { func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
return r.getBlock( return r.getBlock(
ctx, ctx,
"AccountID.TargetAccountID", "AccountID,TargetAccountID",
func(block *gtsmodel.Block) error { func(block *gtsmodel.Block) error {
return r.db.NewSelect().Model(block). return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID). Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
@ -99,27 +101,68 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t
} }
func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) { func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) {
// Preallocate slice of expected length. // Preallocate at-worst possible length.
blocks := make([]*gtsmodel.Block, 0, len(ids)) uncached := make([]string, 0, len(ids))
// Load all blocks IDs via cache loader callbacks.
blocks, err := r.state.Caches.GTS.Block.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids { for _, id := range ids {
// Fetch block model for this ID. if !load(id) {
block, err := r.GetBlockByID(ctx, id) uncached = append(uncached, id)
if err != nil { }
log.Errorf(ctx, "error getting block %q: %v", id, err) }
continue },
// Uncached block loader function.
func() ([]*gtsmodel.Block, error) {
// Preallocate expected length of uncached blocks.
blocks := make([]*gtsmodel.Block, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := r.db.NewSelect().
Model(&blocks).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
} }
// Append to return slice. return blocks, nil
blocks = append(blocks, block) },
)
if err != nil {
return nil, err
} }
// Reorder the blocks by their
// IDs to ensure in correct order.
getID := func(b *gtsmodel.Block) string { return b.ID }
util.OrderBy(blocks, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return blocks, nil
}
// Populate all loaded blocks, removing those we fail to
// populate (removes needing so many nil checks everywhere).
blocks = slices.DeleteFunc(blocks, func(block *gtsmodel.Block) bool {
if err := r.PopulateBlock(ctx, block); err != nil {
log.Errorf(ctx, "error populating block %s: %v", block.ID, err)
return true
}
return false
})
return blocks, nil return blocks, nil
} }
func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) { func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
// Fetch block from cache with loader callback // Fetch block from cache with loader callback
block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { block, err := r.state.Caches.GTS.Block.LoadOne(lookup, func() (*gtsmodel.Block, error) {
var block gtsmodel.Block var block gtsmodel.Block
// Not cached! Perform database query // Not cached! Perform database query
@ -148,8 +191,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu
func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Block) error { func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Block) error {
var ( var (
errs gtserror.MultiError
err error err error
errs = gtserror.NewMultiError(2)
) )
if block.Account == nil { if block.Account == nil {
@ -178,7 +221,7 @@ func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Bloc
} }
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
return r.state.Caches.GTS.Block().Store(block, func() error { return r.state.Caches.GTS.Block.Store(block, func() error {
_, err := r.db.NewInsert().Model(block).Exec(ctx) _, err := r.db.NewInsert().Model(block).Exec(ctx)
return err return err
}) })
@ -198,7 +241,7 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
} }
// Drop this now-cached block on return after delete. // Drop this now-cached block on return after delete.
defer r.state.Caches.GTS.Block().Invalidate("ID", id) defer r.state.Caches.GTS.Block.Invalidate("ID", id)
// Finally delete block from DB. // Finally delete block from DB.
_, err = r.db.NewDelete(). _, err = r.db.NewDelete().
@ -222,7 +265,7 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error
} }
// Drop this now-cached block on return after delete. // Drop this now-cached block on return after delete.
defer r.state.Caches.GTS.Block().Invalidate("URI", uri) defer r.state.Caches.GTS.Block.Invalidate("URI", uri)
// Finally delete block from DB. // Finally delete block from DB.
_, err = r.db.NewDelete(). _, err = r.db.NewDelete().
@ -251,22 +294,20 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
defer func() { defer func() {
// Invalidate all account's incoming / outoing blocks on return. // Invalidate all account's incoming / outoing blocks on return.
r.state.Caches.GTS.Block().Invalidate("AccountID", accountID) r.state.Caches.GTS.Block.Invalidate("AccountID", accountID)
r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID) r.state.Caches.GTS.Block.Invalidate("TargetAccountID", accountID)
}() }()
// Load all blocks into cache, this *really* isn't great // Load all blocks into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all // but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility). // related caches correctly (e.g. visibility).
for _, id := range blockIDs { _, err := r.GetAccountBlocks(ctx, accountID, nil)
_, err := r.GetBlockByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err return err
} }
}
// Finally delete all from DB. // Finally delete all from DB.
_, err := r.db.NewDelete(). _, err = r.db.NewDelete().
Table("blocks"). Table("blocks").
Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)). Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)).
Exec(ctx) Exec(ctx)

View File

@ -21,6 +21,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"slices"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -62,7 +64,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo
func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
return r.getFollow( return r.getFollow(
ctx, ctx,
"AccountID.TargetAccountID", "AccountID,TargetAccountID",
func(follow *gtsmodel.Follow) error { func(follow *gtsmodel.Follow) error {
return r.db.NewSelect(). return r.db.NewSelect().
Model(follow). Model(follow).
@ -76,21 +78,62 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string,
} }
func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
// Preallocate slice of expected length. // Preallocate at-worst possible length.
follows := make([]*gtsmodel.Follow, 0, len(ids)) uncached := make([]string, 0, len(ids))
// Load all follow IDs via cache loader callbacks.
follows, err := r.state.Caches.GTS.Follow.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids { for _, id := range ids {
// Fetch follow model for this ID. if !load(id) {
follow, err := r.GetFollowByID(ctx, id) uncached = append(uncached, id)
if err != nil { }
log.Errorf(ctx, "error getting follow %q: %v", id, err) }
continue },
// Uncached follow loader function.
func() ([]*gtsmodel.Follow, error) {
// Preallocate expected length of uncached follows.
follows := make([]*gtsmodel.Follow, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := r.db.NewSelect().
Model(&follows).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
} }
// Append to return slice. return follows, nil
follows = append(follows, follow) },
)
if err != nil {
return nil, err
} }
// Reorder the follows by their
// IDs to ensure in correct order.
getID := func(f *gtsmodel.Follow) string { return f.ID }
util.OrderBy(follows, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return follows, nil
}
// Populate all loaded follows, removing those we fail to
// populate (removes needing so many nil checks everywhere).
follows = slices.DeleteFunc(follows, func(follow *gtsmodel.Follow) bool {
if err := r.PopulateFollow(ctx, follow); err != nil {
log.Errorf(ctx, "error populating follow %s: %v", follow.ID, err)
return true
}
return false
})
return follows, nil return follows, nil
} }
@ -130,7 +173,7 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 strin
func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) { func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
// Fetch follow from database cache with loader callback // Fetch follow from database cache with loader callback
follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) { follow, err := r.state.Caches.GTS.Follow.LoadOne(lookup, func() (*gtsmodel.Follow, error) {
var follow gtsmodel.Follow var follow gtsmodel.Follow
// Not cached! Perform database query // Not cached! Perform database query
@ -189,7 +232,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo
} }
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
return r.state.Caches.GTS.Follow().Store(follow, func() error { return r.state.Caches.GTS.Follow.Store(follow, func() error {
_, err := r.db.NewInsert().Model(follow).Exec(ctx) _, err := r.db.NewInsert().Model(follow).Exec(ctx)
return err return err
}) })
@ -202,7 +245,7 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
columns = append(columns, "updated_at") columns = append(columns, "updated_at")
} }
return r.state.Caches.GTS.Follow().Store(follow, func() error { return r.state.Caches.GTS.Follow.Store(follow, func() error {
if _, err := r.db.NewUpdate(). if _, err := r.db.NewUpdate().
Model(follow). Model(follow).
Where("? = ?", bun.Ident("follow.id"), follow.ID). Where("? = ?", bun.Ident("follow.id"), follow.ID).
@ -250,7 +293,7 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin
} }
// Drop this now-cached follow on return after delete. // Drop this now-cached follow on return after delete.
defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) defer r.state.Caches.GTS.Follow.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)
// Finally delete follow from DB. // Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID) return r.deleteFollow(ctx, follow.ID)
@ -270,7 +313,7 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error
} }
// Drop this now-cached follow on return after delete. // Drop this now-cached follow on return after delete.
defer r.state.Caches.GTS.Follow().Invalidate("ID", id) defer r.state.Caches.GTS.Follow.Invalidate("ID", id)
// Finally delete follow from DB. // Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID) return r.deleteFollow(ctx, follow.ID)
@ -290,7 +333,7 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro
} }
// Drop this now-cached follow on return after delete. // Drop this now-cached follow on return after delete.
defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) defer r.state.Caches.GTS.Follow.Invalidate("URI", uri)
// Finally delete follow from DB. // Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID) return r.deleteFollow(ctx, follow.ID)
@ -316,22 +359,30 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
defer func() { defer func() {
// Invalidate all account's incoming / outoing follows on return. // Invalidate all account's incoming / outoing follows on return.
r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID) r.state.Caches.GTS.Follow.Invalidate("AccountID", accountID)
r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID) r.state.Caches.GTS.Follow.Invalidate("TargetAccountID", accountID)
}() }()
// Load all follows into cache, this *really* isn't great // Load all follows into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all // but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility). // related caches correctly (e.g. visibility).
for _, id := range followIDs { _, err := r.GetAccountFollows(ctx, accountID, nil)
follow, err := r.GetFollowByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err return err
} }
// Delete each follow from DB. // Delete all follows from DB.
if err := r.deleteFollow(ctx, follow.ID); err != nil && _, err = r.db.NewDelete().
!errors.Is(err, db.ErrNoEntries) { Table("follows").
Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)).
Exec(ctx)
if err != nil {
return err
}
for _, id := range followIDs {
// Finally, delete all list entries associated with each follow ID.
if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {
return err return err
} }
} }

View File

@ -20,6 +20,7 @@ package bundb
import ( import (
"context" "context"
"errors" "errors"
"slices"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -27,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -61,7 +63,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string)
func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) { func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest( return r.getFollowRequest(
ctx, ctx,
"AccountID.TargetAccountID", "AccountID,TargetAccountID",
func(followReq *gtsmodel.FollowRequest) error { func(followReq *gtsmodel.FollowRequest) error {
return r.db.NewSelect(). return r.db.NewSelect().
Model(followReq). Model(followReq).
@ -75,22 +77,63 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s
} }
func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) { func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
// Preallocate slice of expected length. // Preallocate at-worst possible length.
followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids)) uncached := make([]string, 0, len(ids))
// Load all follow IDs via cache loader callbacks.
follows, err := r.state.Caches.GTS.FollowRequest.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids { for _, id := range ids {
// Fetch follow request model for this ID. if !load(id) {
followReq, err := r.GetFollowRequestByID(ctx, id) uncached = append(uncached, id)
}
}
},
// Uncached follow req loader function.
func() ([]*gtsmodel.FollowRequest, error) {
// Preallocate expected length of uncached followReqs.
follows := make([]*gtsmodel.FollowRequest, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := r.db.NewSelect().
Model(&follows).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return follows, nil
},
)
if err != nil { if err != nil {
log.Errorf(ctx, "error getting follow request %q: %v", id, err) return nil, err
continue
} }
// Append to return slice. // Reorder the requests by their
followReqs = append(followReqs, followReq) // IDs to ensure in correct order.
getID := func(f *gtsmodel.FollowRequest) string { return f.ID }
util.OrderBy(follows, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return follows, nil
} }
return followReqs, nil // Populate all loaded followreqs, removing those we fail to
// populate (removes needing so many nil checks everywhere).
follows = slices.DeleteFunc(follows, func(follow *gtsmodel.FollowRequest) bool {
if err := r.PopulateFollowRequest(ctx, follow); err != nil {
log.Errorf(ctx, "error populating follow request %s: %v", follow.ID, err)
return true
}
return false
})
return follows, nil
} }
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) { func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
@ -107,7 +150,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID
func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) { func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
// Fetch follow request from database cache with loader callback // Fetch follow request from database cache with loader callback
followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) { followReq, err := r.state.Caches.GTS.FollowRequest.LoadOne(lookup, func() (*gtsmodel.FollowRequest, error) {
var followReq gtsmodel.FollowRequest var followReq gtsmodel.FollowRequest
// Not cached! Perform database query // Not cached! Perform database query
@ -166,7 +209,7 @@ func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsm
} }
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
return r.state.Caches.GTS.FollowRequest().Store(follow, func() error { return r.state.Caches.GTS.FollowRequest.Store(follow, func() error {
_, err := r.db.NewInsert().Model(follow).Exec(ctx) _, err := r.db.NewInsert().Model(follow).Exec(ctx)
return err return err
}) })
@ -179,7 +222,7 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest
columns = append(columns, "updated_at") columns = append(columns, "updated_at")
} }
return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error { return r.state.Caches.GTS.FollowRequest.Store(followRequest, func() error {
if _, err := r.db.NewUpdate(). if _, err := r.db.NewUpdate().
Model(followRequest). Model(followRequest).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
@ -212,7 +255,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
Notify: followReq.Notify, Notify: followReq.Notify,
} }
if err := r.state.Caches.GTS.Follow().Store(follow, func() error { if err := r.state.Caches.GTS.Follow.Store(follow, func() error {
// If the follow already exists, just // If the follow already exists, just
// replace the URI with the new one. // replace the URI with the new one.
_, err := r.db. _, err := r.db.
@ -274,7 +317,7 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI
} }
// Drop this now-cached follow request on return after delete. // Drop this now-cached follow request on return after delete.
defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) defer r.state.Caches.GTS.FollowRequest.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)
// Finally delete followreq from DB. // Finally delete followreq from DB.
_, err = r.db.NewDelete(). _, err = r.db.NewDelete().
@ -298,7 +341,7 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)
} }
// Drop this now-cached follow request on return after delete. // Drop this now-cached follow request on return after delete.
defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) defer r.state.Caches.GTS.FollowRequest.Invalidate("ID", id)
// Finally delete followreq from DB. // Finally delete followreq from DB.
_, err = r.db.NewDelete(). _, err = r.db.NewDelete().
@ -322,7 +365,7 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin
} }
// Drop this now-cached follow request on return after delete. // Drop this now-cached follow request on return after delete.
defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) defer r.state.Caches.GTS.FollowRequest.Invalidate("URI", uri)
// Finally delete followreq from DB. // Finally delete followreq from DB.
_, err = r.db.NewDelete(). _, err = r.db.NewDelete().
@ -352,22 +395,20 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
defer func() { defer func() {
// Invalidate all account's incoming / outoing follow requests on return. // Invalidate all account's incoming / outoing follow requests on return.
r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID) r.state.Caches.GTS.FollowRequest.Invalidate("AccountID", accountID)
r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID) r.state.Caches.GTS.FollowRequest.Invalidate("TargetAccountID", accountID)
}() }()
// Load all followreqs into cache, this *really* isn't // Load all followreqs into cache, this *really* isn't
// great but it is the only way we can ensure we invalidate // great but it is the only way we can ensure we invalidate
// all related caches correctly (e.g. visibility). // all related caches correctly (e.g. visibility).
for _, id := range followReqIDs { _, err := r.GetAccountFollowRequests(ctx, accountID, nil)
_, err := r.GetFollowRequestByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err return err
} }
}
// Finally delete all from DB. // Finally delete all from DB.
_, err := r.db.NewDelete(). _, err = r.db.NewDelete().
Table("follow_requests"). Table("follow_requests").
Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)). Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)).
Exec(ctx) Exec(ctx)

View File

@ -30,7 +30,7 @@ import (
func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) { func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) {
return r.getNote( return r.getNote(
ctx, ctx,
"AccountID.TargetAccountID", "AccountID,TargetAccountID",
func(note *gtsmodel.AccountNote) error { func(note *gtsmodel.AccountNote) error {
return r.db.NewSelect().Model(note). return r.db.NewSelect().Model(note).
Where("? = ?", bun.Ident("account_id"), sourceAccountID). Where("? = ?", bun.Ident("account_id"), sourceAccountID).
@ -44,7 +44,7 @@ func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, ta
func (r *relationshipDB) getNote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.AccountNote) error, keyParts ...any) (*gtsmodel.AccountNote, error) { func (r *relationshipDB) getNote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.AccountNote) error, keyParts ...any) (*gtsmodel.AccountNote, error) {
// Fetch note from cache with loader callback // Fetch note from cache with loader callback
note, err := r.state.Caches.GTS.AccountNote().Load(lookup, func() (*gtsmodel.AccountNote, error) { note, err := r.state.Caches.GTS.AccountNote.LoadOne(lookup, func() (*gtsmodel.AccountNote, error) {
var note gtsmodel.AccountNote var note gtsmodel.AccountNote
// Not cached! Perform database query // Not cached! Perform database query
@ -105,7 +105,7 @@ func (r *relationshipDB) PopulateNote(ctx context.Context, note *gtsmodel.Accoun
func (r *relationshipDB) PutNote(ctx context.Context, note *gtsmodel.AccountNote) error { func (r *relationshipDB) PutNote(ctx context.Context, note *gtsmodel.AccountNote) error {
note.UpdatedAt = time.Now() note.UpdatedAt = time.Now()
return r.state.Caches.GTS.AccountNote().Store(note, func() error { return r.state.Caches.GTS.AccountNote.Store(note, func() error {
_, err := r.db. _, err := r.db.
NewInsert(). NewInsert().
Model(note). Model(note).

View File

@ -120,7 +120,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str
func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) { func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) {
// Fetch report from database cache with loader callback // Fetch report from database cache with loader callback
report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) { report, err := r.state.Caches.GTS.Report.LoadOne(lookup, func() (*gtsmodel.Report, error) {
var report gtsmodel.Report var report gtsmodel.Report
// Not cached! Perform database query // Not cached! Perform database query
@ -215,7 +215,7 @@ func (r *reportDB) PopulateReport(ctx context.Context, report *gtsmodel.Report)
} }
func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error { func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error {
return r.state.Caches.GTS.Report().Store(report, func() error { return r.state.Caches.GTS.Report.Store(report, func() error {
_, err := r.db.NewInsert().Model(report).Exec(ctx) _, err := r.db.NewInsert().Model(report).Exec(ctx)
return err return err
}) })
@ -237,12 +237,12 @@ func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, co
return nil, err return nil, err
} }
r.state.Caches.GTS.Report().Invalidate("ID", report.ID) r.state.Caches.GTS.Report.Invalidate("ID", report.ID)
return report, nil return report, nil
} }
func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error { func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error {
defer r.state.Caches.GTS.Report().Invalidate("ID", id) defer r.state.Caches.GTS.Report.Invalidate("ID", id)
// Load status into cache before attempting a delete, // Load status into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate // as we need it cached in order to trigger the invalidate

View File

@ -125,7 +125,7 @@ func (r *ruleDB) PutRule(ctx context.Context, rule *gtsmodel.Rule) error {
} }
// invalidate cached local instance response, so it gets updated with the new rules // invalidate cached local instance response, so it gets updated with the new rules
r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost()) r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost())
return nil return nil
} }
@ -143,7 +143,7 @@ func (r *ruleDB) UpdateRule(ctx context.Context, rule *gtsmodel.Rule) (*gtsmodel
} }
// invalidate cached local instance response, so it gets updated with the new rules // invalidate cached local instance response, so it gets updated with the new rules
r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost()) r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost())
return rule, nil return rule, nil
} }

View File

@ -20,6 +20,7 @@ package bundb
import ( import (
"context" "context"
"errors" "errors"
"slices"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -48,20 +50,62 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
} }
func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) { func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) {
statuses := make([]*gtsmodel.Status, 0, len(ids)) // Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all status IDs via cache loader callbacks.
statuses, err := s.state.Caches.GTS.Status.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids { for _, id := range ids {
// Attempt to fetch status from DB. if !load(id) {
status, err := s.GetStatusByID(ctx, id) uncached = append(uncached, id)
if err != nil { }
log.Errorf(ctx, "error getting status %q: %v", id, err) }
continue },
// Uncached statuses loader function.
func() ([]*gtsmodel.Status, error) {
// Preallocate expected length of uncached statuses.
statuses := make([]*gtsmodel.Status, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) status IDs.
if err := s.db.NewSelect().
Model(&statuses).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
} }
// Append status to return slice. return statuses, nil
statuses = append(statuses, status) },
)
if err != nil {
return nil, err
} }
// Reorder the statuses by their
// IDs to ensure in correct order.
getID := func(s *gtsmodel.Status) string { return s.ID }
util.OrderBy(statuses, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return statuses, nil
}
// Populate all loaded statuses, removing those we fail to
// populate (removes needing so many nil checks everywhere).
statuses = slices.DeleteFunc(statuses, func(status *gtsmodel.Status) bool {
if err := s.PopulateStatus(ctx, status); err != nil {
log.Errorf(ctx, "error populating status %s: %v", status.ID, err)
return true
}
return false
})
return statuses, nil return statuses, nil
} }
@ -101,7 +145,7 @@ func (s *statusDB) GetStatusByPollID(ctx context.Context, pollID string) (*gtsmo
func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) { func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) {
return s.getStatus( return s.getStatus(
ctx, ctx,
"BoostOfID.AccountID", "BoostOfID,AccountID",
func(status *gtsmodel.Status) error { func(status *gtsmodel.Status) error {
return s.db.NewSelect().Model(status). return s.db.NewSelect().Model(status).
Where("status.boost_of_id = ?", boostOfID). Where("status.boost_of_id = ?", boostOfID).
@ -120,7 +164,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou
func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) { func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) {
// Fetch status from database cache with loader callback // Fetch status from database cache with loader callback
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) { status, err := s.state.Caches.GTS.Status.LoadOne(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status var status gtsmodel.Status
// Not cached! Perform database query. // Not cached! Perform database query.
@ -282,7 +326,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
} }
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error { func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error {
return s.state.Caches.GTS.Status().Store(status, func() error { return s.state.Caches.GTS.Status.Store(status, func() error {
// It is safe to run this database transaction within cache.Store // It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook. // as the cache does not attempt a mutex lock until AFTER hook.
// //
@ -366,7 +410,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
columns = append(columns, "updated_at") columns = append(columns, "updated_at")
} }
return s.state.Caches.GTS.Status().Store(status, func() error { return s.state.Caches.GTS.Status.Store(status, func() error {
// It is safe to run this database transaction within cache.Store // It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook. // as the cache does not attempt a mutex lock until AFTER hook.
// //
@ -463,7 +507,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {
} }
// On return ensure status invalidated from cache. // On return ensure status invalidated from cache.
defer s.state.Caches.GTS.Status().Invalidate("ID", id) defer s.state.Caches.GTS.Status.Invalidate("ID", id)
return s.db.RunInTx(ctx, func(tx Tx) error { return s.db.RunInTx(ctx, func(tx Tx) error {
// delete links between this status and any emojis it uses // delete links between this status and any emojis it uses
@ -585,7 +629,7 @@ func (s *statusDB) CountStatusReplies(ctx context.Context, statusID string) (int
} }
func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]string, error) { func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]string, error) {
return s.state.Caches.GTS.InReplyToIDs().Load(statusID, func() ([]string, error) { return s.state.Caches.GTS.InReplyToIDs.Load(statusID, func() ([]string, error) {
var statusIDs []string var statusIDs []string
// Status reply IDs not in cache, perform DB query! // Status reply IDs not in cache, perform DB query!
@ -629,7 +673,7 @@ func (s *statusDB) CountStatusBoosts(ctx context.Context, statusID string) (int,
} }
func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]string, error) { func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]string, error) {
return s.state.Caches.GTS.BoostOfIDs().Load(statusID, func() ([]string, error) { return s.state.Caches.GTS.BoostOfIDs.Load(statusID, func() ([]string, error) {
var statusIDs []string var statusIDs []string
// Status boost IDs not in cache, perform DB query! // Status boost IDs not in cache, perform DB query!

View File

@ -22,6 +22,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@ -29,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -40,7 +42,7 @@ type statusFaveDB struct {
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) { func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) {
return s.getStatusFave( return s.getStatusFave(
ctx, ctx,
"AccountID.StatusID", "AccountID,StatusID",
func(fave *gtsmodel.StatusFave) error { func(fave *gtsmodel.StatusFave) error {
return s.db. return s.db.
NewSelect(). NewSelect().
@ -77,7 +79,7 @@ func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmo
func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) { func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {
// Fetch status fave from database cache with loader callback // Fetch status fave from database cache with loader callback
fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) { fave, err := s.state.Caches.GTS.StatusFave.LoadOne(lookup, func() (*gtsmodel.StatusFave, error) {
var fave gtsmodel.StatusFave var fave gtsmodel.StatusFave
// Not cached! Perform database query. // Not cached! Perform database query.
@ -111,19 +113,62 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*
return nil, err return nil, err
} }
// Preallocate a slice of expected status fave capacity. // Preallocate at-worst possible length.
faves := make([]*gtsmodel.StatusFave, 0, len(faveIDs)) uncached := make([]string, 0, len(faveIDs))
// Load all fave IDs via cache loader callbacks.
faves, err := s.state.Caches.GTS.StatusFave.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range faveIDs { for _, id := range faveIDs {
// Fetch status fave model for each ID. if !load(id) {
fave, err := s.GetStatusFaveByID(ctx, id) uncached = append(uncached, id)
}
}
},
// Uncached status faves loader function.
func() ([]*gtsmodel.StatusFave, error) {
// Preallocate expected length of uncached faves.
faves := make([]*gtsmodel.StatusFave, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) fave IDs.
if err := s.db.NewSelect().
Model(&faves).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return faves, nil
},
)
if err != nil { if err != nil {
log.Errorf(ctx, "error getting status fave %q: %v", id, err) return nil, err
continue
} }
faves = append(faves, fave)
// Reorder the statuses by their
// IDs to ensure in correct order.
getID := func(f *gtsmodel.StatusFave) string { return f.ID }
util.OrderBy(faves, faveIDs, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return faves, nil
} }
// Populate all loaded faves, removing those we fail to
// populate (removes needing so many nil checks everywhere).
faves = slices.DeleteFunc(faves, func(fave *gtsmodel.StatusFave) bool {
if err := s.PopulateStatusFave(ctx, fave); err != nil {
log.Errorf(ctx, "error populating fave %s: %v", fave.ID, err)
return true
}
return false
})
return faves, nil return faves, nil
} }
@ -141,7 +186,7 @@ func (s *statusFaveDB) CountStatusFaves(ctx context.Context, statusID string) (i
} }
func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) { func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) {
return s.state.Caches.GTS.StatusFaveIDs().Load(statusID, func() ([]string, error) { return s.state.Caches.GTS.StatusFaveIDs.Load(statusID, func() ([]string, error) {
var faveIDs []string var faveIDs []string
// Status fave IDs not in cache, perform DB query! // Status fave IDs not in cache, perform DB query!
@ -201,7 +246,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo
} }
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error { func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error {
return s.state.Caches.GTS.StatusFave().Store(fave, func() error { return s.state.Caches.GTS.StatusFave.Store(fave, func() error {
_, err := s.db. _, err := s.db.
NewInsert(). NewInsert().
Model(fave). Model(fave).
@ -230,10 +275,10 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) erro
if statusID != "" { if statusID != "" {
// Invalidate any cached status faves for this status. // Invalidate any cached status faves for this status.
s.state.Caches.GTS.StatusFave().Invalidate("ID", id) s.state.Caches.GTS.StatusFave.Invalidate("ID", id)
// Invalidate any cached status fave IDs for this status. // Invalidate any cached status fave IDs for this status.
s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID)
} }
return nil return nil
@ -270,17 +315,15 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
return err return err
} }
// Collate (deduplicating) status IDs. // Deduplicate determined status IDs.
statusIDs = collate(func(i int) string { statusIDs = util.Deduplicate(statusIDs)
return statusIDs[i]
}, len(statusIDs))
for _, id := range statusIDs { for _, id := range statusIDs {
// Invalidate any cached status faves for this status. // Invalidate any cached status faves for this status.
s.state.Caches.GTS.StatusFave().Invalidate("ID", id) s.state.Caches.GTS.StatusFave.Invalidate("ID", id)
// Invalidate any cached status fave IDs for this status. // Invalidate any cached status fave IDs for this status.
s.state.Caches.GTS.StatusFaveIDs().Invalidate(id) s.state.Caches.GTS.StatusFaveIDs.Invalidate(id)
} }
return nil return nil
@ -296,10 +339,10 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID
} }
// Invalidate any cached status faves for this status. // Invalidate any cached status faves for this status.
s.state.Caches.GTS.StatusFave().Invalidate("ID", statusID) s.state.Caches.GTS.StatusFave.Invalidate("ID", statusID)
// Invalidate any cached status fave IDs for this status. // Invalidate any cached status fave IDs for this status.
s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID)
return nil return nil
} }

View File

@ -22,21 +22,21 @@ import (
"strings" "strings"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type tagDB struct { type tagDB struct {
conn *DB db *DB
state *state.State state *state.State
} }
func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { func (t *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
return m.state.Caches.GTS.Tag().Load("ID", func() (*gtsmodel.Tag, error) { return t.state.Caches.GTS.Tag.LoadOne("ID", func() (*gtsmodel.Tag, error) {
var tag gtsmodel.Tag var tag gtsmodel.Tag
q := m.conn. q := t.db.
NewSelect(). NewSelect().
Model(&tag). Model(&tag).
Where("? = ?", bun.Ident("tag.id"), id) Where("? = ?", bun.Ident("tag.id"), id)
@ -49,15 +49,15 @@ func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
}, id) }, id)
} }
func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) { func (t *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) {
// Normalize 'name' string. // Normalize 'name' string.
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
name = strings.ToLower(name) name = strings.ToLower(name)
return m.state.Caches.GTS.Tag().Load("Name", func() (*gtsmodel.Tag, error) { return t.state.Caches.GTS.Tag.LoadOne("Name", func() (*gtsmodel.Tag, error) {
var tag gtsmodel.Tag var tag gtsmodel.Tag
q := m.conn. q := t.db.
NewSelect(). NewSelect().
Model(&tag). Model(&tag).
Where("? = ?", bun.Ident("tag.name"), name) Where("? = ?", bun.Ident("tag.name"), name)
@ -70,25 +70,52 @@ func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, e
}, name) }, name)
} }
func (m *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) { func (t *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) {
tags := make([]*gtsmodel.Tag, 0, len(ids)) // Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all tag IDs via cache loader callbacks.
tags, err := t.state.Caches.GTS.Tag.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids { for _, id := range ids {
// Attempt fetch from DB if !load(id) {
tag, err := m.GetTag(ctx, id) uncached = append(uncached, id)
if err != nil { }
log.Errorf(ctx, "error getting tag %q: %v", id, err) }
continue },
// Uncached tag loader function.
func() ([]*gtsmodel.Tag, error) {
// Preallocate expected length of uncached tags.
tags := make([]*gtsmodel.Tag, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := t.db.NewSelect().
Model(&tags).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
} }
// Append tag return tags, nil
tags = append(tags, tag) },
)
if err != nil {
return nil, err
} }
// Reorder the tags by their
// IDs to ensure in correct order.
getID := func(t *gtsmodel.Tag) string { return t.ID }
util.OrderBy(tags, ids, getID)
return tags, nil return tags, nil
} }
func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { func (t *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
// Normalize 'name' string before it enters // Normalize 'name' string before it enters
// the db, without changing tag we were given. // the db, without changing tag we were given.
// //
@ -101,8 +128,8 @@ func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
t2.Name = strings.ToLower(t2.Name) t2.Name = strings.ToLower(t2.Name)
// Insert the copy. // Insert the copy.
if err := m.state.Caches.GTS.Tag().Store(t2, func() error { if err := t.state.Caches.GTS.Tag.Store(t2, func() error {
_, err := m.conn.NewInsert().Model(t2).Exec(ctx) _, err := t.db.NewInsert().Model(t2).Exec(ctx)
return err return err
}); err != nil { }); err != nil {
return err // err already processed return err // err already processed

View File

@ -42,7 +42,7 @@ func (t *threadDB) PutThread(ctx context.Context, thread *gtsmodel.Thread) error
} }
func (t *threadDB) GetThreadMute(ctx context.Context, id string) (*gtsmodel.ThreadMute, error) { func (t *threadDB) GetThreadMute(ctx context.Context, id string) (*gtsmodel.ThreadMute, error) {
return t.state.Caches.GTS.ThreadMute().Load("ID", func() (*gtsmodel.ThreadMute, error) { return t.state.Caches.GTS.ThreadMute.LoadOne("ID", func() (*gtsmodel.ThreadMute, error) {
var threadMute gtsmodel.ThreadMute var threadMute gtsmodel.ThreadMute
q := t.db. q := t.db.
@ -63,7 +63,7 @@ func (t *threadDB) GetThreadMutedByAccount(
threadID string, threadID string,
accountID string, accountID string,
) (*gtsmodel.ThreadMute, error) { ) (*gtsmodel.ThreadMute, error) {
return t.state.Caches.GTS.ThreadMute().Load("ThreadID.AccountID", func() (*gtsmodel.ThreadMute, error) { return t.state.Caches.GTS.ThreadMute.LoadOne("ThreadID,AccountID", func() (*gtsmodel.ThreadMute, error) {
var threadMute gtsmodel.ThreadMute var threadMute gtsmodel.ThreadMute
q := t.db. q := t.db.
@ -98,7 +98,7 @@ func (t *threadDB) IsThreadMutedByAccount(
} }
func (t *threadDB) PutThreadMute(ctx context.Context, threadMute *gtsmodel.ThreadMute) error { func (t *threadDB) PutThreadMute(ctx context.Context, threadMute *gtsmodel.ThreadMute) error {
return t.state.Caches.GTS.ThreadMute().Store(threadMute, func() error { return t.state.Caches.GTS.ThreadMute.Store(threadMute, func() error {
_, err := t.db.NewInsert().Model(threadMute).Exec(ctx) _, err := t.db.NewInsert().Model(threadMute).Exec(ctx)
return err return err
}) })
@ -112,6 +112,6 @@ func (t *threadDB) DeleteThreadMute(ctx context.Context, id string) error {
return err return err
} }
t.state.Caches.GTS.ThreadMute().Invalidate("ID", id) t.state.Caches.GTS.ThreadMute.Invalidate("ID", id)
return nil return nil
} }

View File

@ -29,7 +29,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -155,20 +154,8 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
} }
} }
statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) // Return status IDs loaded from cache + db.
for _, id := range statusIDs { return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
} }
func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
@ -256,20 +243,8 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
} }
} }
statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) // Return status IDs loaded from cache + db.
for _, id := range statusIDs { return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
} }
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
@ -323,18 +298,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
} }
}) })
statuses := make([]*gtsmodel.Status, 0, len(faves)) // Convert fave IDs to status IDs.
statusIDs := make([]string, len(faves))
for _, fave := range faves { for i, fave := range faves {
// Fetch status from db for corresponding favourite statusIDs[i] = fave.StatusID
status, err := t.state.DB.GetStatusByID(ctx, fave.StatusID)
if err != nil {
log.Errorf(ctx, "error fetching status for fave %q: %v", fave.ID, err)
continue
} }
// Append status to slice statuses, err := t.state.DB.GetStatusesByIDs(ctx, statusIDs)
statuses = append(statuses, status) if err != nil {
return nil, "", "", err
} }
nextMaxID := faves[len(faves)-1].ID nextMaxID := faves[len(faves)-1].ID
@ -453,20 +425,8 @@ func (t *timelineDB) GetListTimeline(
} }
} }
statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) // Return status IDs loaded from cache + db.
for _, id := range statusIDs { return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
} }
func (t *timelineDB) GetTagTimeline( func (t *timelineDB) GetTagTimeline(
@ -561,18 +521,6 @@ func (t *timelineDB) GetTagTimeline(
} }
} }
statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) // Return status IDs loaded from cache + db.
for _, id := range statusIDs { return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
} }

View File

@ -32,7 +32,7 @@ type tombstoneDB struct {
} }
func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) { func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) {
return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) { return t.state.Caches.GTS.Tombstone.LoadOne("URI", func() (*gtsmodel.Tombstone, error) {
var tomb gtsmodel.Tombstone var tomb gtsmodel.Tombstone
q := t.db. q := t.db.
@ -57,7 +57,7 @@ func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (b
} }
func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error { func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error {
return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error { return t.state.Caches.GTS.Tombstone.Store(tombstone, func() error {
_, err := t.db. _, err := t.db.
NewInsert(). NewInsert().
Model(tombstone). Model(tombstone).
@ -67,7 +67,7 @@ func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tomb
} }
func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error { func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error {
defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id) defer t.state.Caches.GTS.Tombstone.Invalidate("ID", id)
// Delete tombstone from DB. // Delete tombstone from DB.
_, err := t.db.NewDelete(). _, err := t.db.NewDelete().

View File

@ -116,7 +116,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (
func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) { func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) {
// Fetch user from database cache with loader callback. // Fetch user from database cache with loader callback.
user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) { user, err := u.state.Caches.GTS.User.LoadOne(lookup, func() (*gtsmodel.User, error) {
var user gtsmodel.User var user gtsmodel.User
// Not cached! perform database query. // Not cached! perform database query.
@ -179,7 +179,7 @@ func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
} }
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error {
return u.state.Caches.GTS.User().Store(user, func() error { return u.state.Caches.GTS.User.Store(user, func() error {
_, err := u.db. _, err := u.db.
NewInsert(). NewInsert().
Model(user). Model(user).
@ -197,7 +197,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
columns = append(columns, "updated_at") columns = append(columns, "updated_at")
} }
return u.state.Caches.GTS.User().Store(user, func() error { return u.state.Caches.GTS.User.Store(user, func() error {
_, err := u.db. _, err := u.db.
NewUpdate(). NewUpdate().
Model(user). Model(user).
@ -209,7 +209,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
} }
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error { func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error {
defer u.state.Caches.GTS.User().Invalidate("ID", userID) defer u.state.Caches.GTS.User.Invalidate("ID", userID)
// Load user into cache before attempting a delete, // Load user into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate // as we need it cached in order to trigger the invalidate

View File

@ -27,6 +27,9 @@ type List interface {
// GetListByID gets one list with the given id. // GetListByID gets one list with the given id.
GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error)
// GetListsByIDs fetches all lists with the provided IDs.
GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error)
// GetListsForAccountID gets all lists owned by the given accountID. // GetListsForAccountID gets all lists owned by the given accountID.
GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error)
@ -46,6 +49,9 @@ type List interface {
// GetListEntryByID gets one list entry with the given ID. // GetListEntryByID gets one list entry with the given ID.
GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error)
// GetListEntriesyIDs fetches all list entries with the provided IDs.
GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error)
// GetListEntries gets list entries from the given listID, using the given parameters. // GetListEntries gets list entries from the given listID, using the given parameters.
GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error) GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error)

View File

@ -33,6 +33,9 @@ type Notification interface {
// GetNotification returns one notification according to its id. // GetNotification returns one notification according to its id.
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error)
// GetNotificationsByIDs returns a slice of notifications of the the provided IDs.
GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error)
// GetNotification gets one notification according to the provided parameters, if it exists. // GetNotification gets one notification according to the provided parameters, if it exists.
// Since not all notifications are about a status, statusID can be an empty string. // Since not all notifications are about a status, statusID can be an empty string.
GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, error) GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, error)

View File

@ -107,19 +107,21 @@ func (d *Dereferencer) EnrichAnnounce(
// All good baby. // All good baby.
case errors.Is(err, db.ErrAlreadyExists): case errors.Is(err, db.ErrAlreadyExists):
uri := boost.URI
// DATA RACE! We likely lost out to another goroutine // DATA RACE! We likely lost out to another goroutine
// in a call to db.Put(Status). Look again in DB by URI. // in a call to db.Put(Status). Look again in DB by URI.
boost, err = d.state.DB.GetStatusByURI(ctx, boost.URI) boost, err = d.state.DB.GetStatusByURI(ctx, uri)
if err != nil { if err != nil {
err = gtserror.Newf( return nil, gtserror.Newf(
"error getting boost wrapper status %s from database after race: %w", "error getting boost wrapper status %s from database after race: %w",
boost.URI, err, uri, err,
) )
} }
default: default:
// Proper database error. // Proper database error.
err = gtserror.Newf("db error inserting status: %w", err) return nil, gtserror.Newf("db error inserting status: %w", err)
} }
return boost, err return boost, err

View File

@ -79,9 +79,7 @@ func (suite *AnnounceTestSuite) TestAnnounceTwice() {
// Insert the boost-of status into the // Insert the boost-of status into the
// DB cache to emulate processor handling // DB cache to emulate processor handling
boost.ID, _ = id.NewULIDFromTime(boost.CreatedAt) boost.ID, _ = id.NewULIDFromTime(boost.CreatedAt)
suite.state.Caches.GTS.Status().Store(boost, func() error { suite.state.Caches.GTS.Status.Put(boost)
return nil
})
// only the URI will be set for the boosted status // only the URI will be set for the boosted status
// because it still needs to be dereferenced // because it still needs to be dereferenced

View File

@ -55,7 +55,6 @@ func (m *Manager) RefetchEmojis(ctx context.Context, domain string, dereferenceM
emojis, err := m.state.DB.GetEmojisBy(ctx, domain, false, true, "", maxShortcodeDomain, "", 20) emojis, err := m.state.DB.GetEmojisBy(ctx, domain, false, true, "", maxShortcodeDomain, "", 20)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
// an actual error has occurred
log.Errorf(ctx, "error fetching emojis from database: %s", err) log.Errorf(ctx, "error fetching emojis from database: %s", err)
} }
break break

View File

@ -229,6 +229,7 @@ func (p *Processor) processMediaIDs(ctx context.Context, form *apimodel.Advanced
attachments := []*gtsmodel.MediaAttachment{} attachments := []*gtsmodel.MediaAttachment{}
attachmentIDs := []string{} attachmentIDs := []string{}
for _, mediaID := range form.MediaIDs { for _, mediaID := range form.MediaIDs {
attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaID) attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaID)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {

View File

@ -82,7 +82,7 @@ func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*ur
// Attempt to deliver data to recipient. // Attempt to deliver data to recipient.
if err := t.deliver(ctx, b, to); err != nil { if err := t.deliver(ctx, b, to); err != nil {
mutex.Lock() // safely append err to accumulator. mutex.Lock() // safely append err to accumulator.
errs.Appendf("error delivering to %s: %v", to, err) errs.Appendf("error delivering to %s: %w", to, err)
mutex.Unlock() mutex.Unlock()
} }
} }

View File

@ -36,7 +36,8 @@ import (
func (t *transport) webfingerURLFor(targetDomain string) (string, bool) { func (t *transport) webfingerURLFor(targetDomain string) (string, bool) {
url := "https://" + targetDomain + "/.well-known/webfinger" url := "https://" + targetDomain + "/.well-known/webfinger"
wc := t.controller.state.Caches.GTS.Webfinger() wc := t.controller.state.Caches.GTS.Webfinger
// We're doing the manual locking/unlocking here to be able to // We're doing the manual locking/unlocking here to be able to
// safely call Cache.Get instead of Get, as the latter updates the // safely call Cache.Get instead of Get, as the latter updates the
// item expiry which we don't want to do here // item expiry which we don't want to do here
@ -95,7 +96,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
// If we got a response we consider successful on a cached URL, i.e one set // If we got a response we consider successful on a cached URL, i.e one set
// by us later on when a host-meta based webfinger request succeeded, set it // by us later on when a host-meta based webfinger request succeeded, set it
// again here to renew the TTL // again here to renew the TTL
t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, url) t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, url)
} }
if rsp.StatusCode == http.StatusGone { if rsp.StatusCode == http.StatusGone {
return nil, fmt.Errorf("account has been deleted/is gone") return nil, fmt.Errorf("account has been deleted/is gone")
@ -151,7 +152,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
// we asked for is gone. This means the endpoint itself is valid and we should // we asked for is gone. This means the endpoint itself is valid and we should
// cache it for future queries to the same domain // cache it for future queries to the same domain
if rsp.StatusCode == http.StatusGone { if rsp.StatusCode == http.StatusGone {
t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, host) t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, host)
return nil, fmt.Errorf("account has been deleted/is gone") return nil, fmt.Errorf("account has been deleted/is gone")
} }
// We've reached the end of the line here, both the original request // We've reached the end of the line here, both the original request
@ -162,7 +163,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
// Set the URL in cache here, since host-meta told us this should be the // Set the URL in cache here, since host-meta told us this should be the
// valid one, it's different from the default and our request to it did // valid one, it's different from the default and our request to it did
// not fail in any manner // not fail in any manner
t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, host) t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, host)
return io.ReadAll(rsp.Body) return io.ReadAll(rsp.Body)
} }

View File

@ -31,7 +31,7 @@ type FingerTestSuite struct {
} }
func (suite *FingerTestSuite) TestFinger() { func (suite *FingerTestSuite) TestFinger() {
wc := suite.state.Caches.GTS.Webfinger() wc := suite.state.Caches.GTS.Webfinger
suite.Equal(0, wc.Len(), "expect webfinger cache to be empty") suite.Equal(0, wc.Len(), "expect webfinger cache to be empty")
_, err := suite.transport.Finger(context.TODO(), "brand_new_person", "unknown-instance.com") _, err := suite.transport.Finger(context.TODO(), "brand_new_person", "unknown-instance.com")
@ -43,7 +43,7 @@ func (suite *FingerTestSuite) TestFinger() {
} }
func (suite *FingerTestSuite) TestFingerWithHostMeta() { func (suite *FingerTestSuite) TestFingerWithHostMeta() {
wc := suite.state.Caches.GTS.Webfinger() wc := suite.state.Caches.GTS.Webfinger
suite.Equal(0, wc.Len(), "expect webfinger cache to be empty") suite.Equal(0, wc.Len(), "expect webfinger cache to be empty")
_, err := suite.transport.Finger(context.TODO(), "someone", "misconfigured-instance.com") _, err := suite.transport.Finger(context.TODO(), "someone", "misconfigured-instance.com")
@ -60,7 +60,7 @@ func (suite *FingerTestSuite) TestFingerWithHostMetaCacheStrategy() {
suite.T().Skip("this test is flaky on CI for as of yet unknown reasons") suite.T().Skip("this test is flaky on CI for as of yet unknown reasons")
} }
wc := suite.state.Caches.GTS.Webfinger() wc := suite.state.Caches.GTS.Webfinger
// Reset the sweep frequency so nothing interferes with the test // Reset the sweep frequency so nothing interferes with the test
wc.Stop() wc.Stop()

View File

@ -794,7 +794,6 @@ func (c *Converter) getASAttributedToAccount(ctx context.Context, id string, wit
} }
return account, nil return account, nil
} }
func (c *Converter) getASObjectAccount(ctx context.Context, id string, with ap.WithObject) (*gtsmodel.Account, error) { func (c *Converter) getASObjectAccount(ctx context.Context, id string, with ap.WithObject) (*gtsmodel.Account, error) {

View File

@ -491,7 +491,7 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat
// tag -- mentions // tag -- mentions
mentions := s.Mentions mentions := s.Mentions
if len(s.MentionIDs) > len(mentions) { if len(s.MentionIDs) != len(mentions) {
mentions, err = c.state.DB.GetMentions(ctx, s.MentionIDs) mentions, err = c.state.DB.GetMentions(ctx, s.MentionIDs)
if err != nil { if err != nil {
return nil, gtserror.Newf("error getting mentions: %w", err) return nil, gtserror.Newf("error getting mentions: %w", err)
@ -507,14 +507,10 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat
// tag -- emojis // tag -- emojis
emojis := s.Emojis emojis := s.Emojis
if len(s.EmojiIDs) > len(emojis) { if len(s.EmojiIDs) != len(emojis) {
emojis = []*gtsmodel.Emoji{} emojis, err = c.state.DB.GetEmojisByIDs(ctx, s.EmojiIDs)
for _, emojiID := range s.EmojiIDs {
emoji, err := c.state.DB.GetEmojiByID(ctx, emojiID)
if err != nil { if err != nil {
return nil, gtserror.Newf("error getting emoji %s from database: %w", emojiID, err) return nil, gtserror.Newf("error getting emojis from database: %w", err)
}
emojis = append(emojis, emoji)
} }
} }
for _, emoji := range emojis { for _, emoji := range emojis {
@ -527,7 +523,7 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat
// tag -- hashtags // tag -- hashtags
hashtags := s.Tags hashtags := s.Tags
if len(s.TagIDs) > len(hashtags) { if len(s.TagIDs) != len(hashtags) {
hashtags, err = c.state.DB.GetTags(ctx, s.TagIDs) hashtags, err = c.state.DB.GetTags(ctx, s.TagIDs)
if err != nil { if err != nil {
return nil, gtserror.Newf("error getting tags: %w", err) return nil, gtserror.Newf("error getting tags: %w", err)
@ -623,14 +619,10 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat
// attachments // attachments
attachmentProp := streams.NewActivityStreamsAttachmentProperty() attachmentProp := streams.NewActivityStreamsAttachmentProperty()
attachments := s.Attachments attachments := s.Attachments
if len(s.AttachmentIDs) > len(attachments) { if len(s.AttachmentIDs) != len(attachments) {
attachments = []*gtsmodel.MediaAttachment{} attachments, err = c.state.DB.GetAttachmentsByIDs(ctx, s.AttachmentIDs)
for _, attachmentID := range s.AttachmentIDs {
attachment, err := c.state.DB.GetAttachmentByID(ctx, attachmentID)
if err != nil { if err != nil {
return nil, gtserror.Newf("error getting attachment %s from database: %w", attachmentID, err) return nil, gtserror.Newf("error getting attachments from database: %w", err)
}
attachments = append(attachments, attachment)
} }
} }
for _, a := range attachments { for _, a := range attachments {

View File

@ -1563,20 +1563,15 @@ func (c *Converter) PollToAPIPoll(ctx context.Context, requester *gtsmodel.Accou
func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, attachments []*gtsmodel.MediaAttachment, attachmentIDs []string) ([]*apimodel.Attachment, error) { func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, attachments []*gtsmodel.MediaAttachment, attachmentIDs []string) ([]*apimodel.Attachment, error) {
var errs gtserror.MultiError var errs gtserror.MultiError
if len(attachments) == 0 { if len(attachments) == 0 && len(attachmentIDs) > 0 {
// GTS model attachments were not populated // GTS model attachments were not populated
// Preallocate expected GTS slice var err error
attachments = make([]*gtsmodel.MediaAttachment, 0, len(attachmentIDs))
// Fetch GTS models for attachment IDs // Fetch GTS models for attachment IDs
for _, id := range attachmentIDs { attachments, err = c.state.DB.GetAttachmentsByIDs(ctx, attachmentIDs)
attachment, err := c.state.DB.GetAttachmentByID(ctx, id)
if err != nil { if err != nil {
errs.Appendf("error fetching attachment %s from database: %v", id, err) errs.Appendf("error fetching attachments from database: %w", err)
continue
}
attachments = append(attachments, attachment)
} }
} }
@ -1587,7 +1582,7 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta
for _, attachment := range attachments { for _, attachment := range attachments {
apiAttachment, err := c.AttachmentToAPIAttachment(ctx, attachment) apiAttachment, err := c.AttachmentToAPIAttachment(ctx, attachment)
if err != nil { if err != nil {
errs.Appendf("error converting attchment %s to api attachment: %v", attachment.ID, err) errs.Appendf("error converting attchment %s to api attachment: %w", attachment.ID, err)
continue continue
} }
apiAttachments = append(apiAttachments, &apiAttachment) apiAttachments = append(apiAttachments, &apiAttachment)
@ -1600,20 +1595,15 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta
func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsmodel.Emoji, emojiIDs []string) ([]apimodel.Emoji, error) { func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsmodel.Emoji, emojiIDs []string) ([]apimodel.Emoji, error) {
var errs gtserror.MultiError var errs gtserror.MultiError
if len(emojis) == 0 { if len(emojis) == 0 && len(emojiIDs) > 0 {
// GTS model attachments were not populated // GTS model attachments were not populated
// Preallocate expected GTS slice var err error
emojis = make([]*gtsmodel.Emoji, 0, len(emojiIDs))
// Fetch GTS models for emoji IDs // Fetch GTS models for emoji IDs
for _, id := range emojiIDs { emojis, err = c.state.DB.GetEmojisByIDs(ctx, emojiIDs)
emoji, err := c.state.DB.GetEmojiByID(ctx, id)
if err != nil { if err != nil {
errs.Appendf("error fetching emoji %s from database: %v", id, err) errs.Appendf("error fetching emojis from database: %w", err)
continue
}
emojis = append(emojis, emoji)
} }
} }
@ -1624,7 +1614,7 @@ func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsm
for _, emoji := range emojis { for _, emoji := range emojis {
apiEmoji, err := c.EmojiToAPIEmoji(ctx, emoji) apiEmoji, err := c.EmojiToAPIEmoji(ctx, emoji)
if err != nil { if err != nil {
errs.Appendf("error converting emoji %s to api emoji: %v", emoji.ID, err) errs.Appendf("error converting emoji %s to api emoji: %w", emoji.ID, err)
continue continue
} }
apiEmojis = append(apiEmojis, apiEmoji) apiEmojis = append(apiEmojis, apiEmoji)
@ -1637,7 +1627,7 @@ func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsm
func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions []*gtsmodel.Mention, mentionIDs []string) ([]apimodel.Mention, error) { func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions []*gtsmodel.Mention, mentionIDs []string) ([]apimodel.Mention, error) {
var errs gtserror.MultiError var errs gtserror.MultiError
if len(mentions) == 0 { if len(mentions) == 0 && len(mentionIDs) > 0 {
var err error var err error
// GTS model mentions were not populated // GTS model mentions were not populated
@ -1645,7 +1635,7 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [
// Fetch GTS models for mention IDs // Fetch GTS models for mention IDs
mentions, err = c.state.DB.GetMentions(ctx, mentionIDs) mentions, err = c.state.DB.GetMentions(ctx, mentionIDs)
if err != nil { if err != nil {
errs.Appendf("error fetching mentions from database: %v", err) errs.Appendf("error fetching mentions from database: %w", err)
} }
} }
@ -1656,7 +1646,7 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [
for _, mention := range mentions { for _, mention := range mentions {
apiMention, err := c.MentionToAPIMention(ctx, mention) apiMention, err := c.MentionToAPIMention(ctx, mention)
if err != nil { if err != nil {
errs.Appendf("error converting mention %s to api mention: %v", mention.ID, err) errs.Appendf("error converting mention %s to api mention: %w", mention.ID, err)
continue continue
} }
apiMentions = append(apiMentions, apiMention) apiMentions = append(apiMentions, apiMention)
@ -1669,12 +1659,12 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [
func (c *Converter) convertTagsToAPITags(ctx context.Context, tags []*gtsmodel.Tag, tagIDs []string) ([]apimodel.Tag, error) { func (c *Converter) convertTagsToAPITags(ctx context.Context, tags []*gtsmodel.Tag, tagIDs []string) ([]apimodel.Tag, error) {
var errs gtserror.MultiError var errs gtserror.MultiError
if len(tags) == 0 { if len(tags) == 0 && len(tagIDs) > 0 {
var err error var err error
tags, err = c.state.DB.GetTags(ctx, tagIDs) tags, err = c.state.DB.GetTags(ctx, tagIDs)
if err != nil { if err != nil {
errs.Appendf("error fetching tags from database: %v", err) errs.Appendf("error fetching tags from database: %w", err)
} }
} }
@ -1685,7 +1675,7 @@ func (c *Converter) convertTagsToAPITags(ctx context.Context, tags []*gtsmodel.T
for _, tag := range tags { for _, tag := range tags {
apiTag, err := c.TagToAPITag(ctx, tag, false) apiTag, err := c.TagToAPITag(ctx, tag, false)
if err != nil { if err != nil {
errs.Appendf("error converting tag %s to api tag: %v", tag.ID, err) errs.Appendf("error converting tag %s to api tag: %w", tag.ID, err)
continue continue
} }
apiTags = append(apiTags, apiTag) apiTags = append(apiTags, apiTag)

View File

@ -61,3 +61,75 @@ func DeduplicateFunc[T any, C comparable](in []T, key func(v T) C) []T {
return deduped return deduped
} }
// Collate will collect the values of type K from input type []T,
// passing each item to 'get' and deduplicating the end result.
// Compared to Deduplicate() this returns []K, NOT input type []T.
func Collate[T any, K comparable](in []T, get func(T) K) []K {
ks := make([]K, 0, len(in))
km := make(map[K]struct{}, len(in))
for i := 0; i < len(in); i++ {
// Get next k.
k := get(in[i])
if _, ok := km[k]; !ok {
// New value, add
// to map + slice.
ks = append(ks, k)
km[k] = struct{}{}
}
}
return ks
}
// OrderBy orders a slice of given type by the provided alternative slice of comparable type.
func OrderBy[T any, K comparable](in []T, keys []K, key func(T) K) {
var (
start int
offset int
)
for i := 0; i < len(keys); i++ {
var (
// key at index.
k = keys[i]
// sentinel
// idx value.
idx = -1
)
// Look for model with key in slice.
for j := start; j < len(in); j++ {
if key(in[j]) == k {
idx = j
break
}
}
if idx == -1 {
// model with key
// was not found.
offset++
continue
}
// Update
// start
start++
// Expected ID index.
exp := i - offset
if idx == exp {
// Model is in expected
// location, keep going.
continue
}
// Swap models at current and expected.
in[idx], in[exp] = in[exp], in[idx]
}
}

View File

@ -39,7 +39,7 @@ func (f *Filter) AccountVisible(ctx context.Context, requester *gtsmodel.Account
requesterID = requester.ID requesterID = requester.ID
} }
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) { visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform visibility lookup. // Visibility not yet cached, perform visibility lookup.
visible, err := f.isAccountVisibleTo(ctx, requester, account) visible, err := f.isAccountVisibleTo(ctx, requester, account)
if err != nil { if err != nil {

View File

@ -42,7 +42,7 @@ func (f *Filter) StatusHomeTimelineable(ctx context.Context, owner *gtsmodel.Acc
requesterID = owner.ID requesterID = owner.ID
} }
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) { visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform timeline visibility lookup. // Visibility not yet cached, perform timeline visibility lookup.
visible, err := f.isStatusHomeTimelineable(ctx, owner, status) visible, err := f.isStatusHomeTimelineable(ctx, owner, status)
if err != nil { if err != nil {

View File

@ -40,7 +40,7 @@ func (f *Filter) StatusPublicTimelineable(ctx context.Context, requester *gtsmod
requesterID = requester.ID requesterID = requester.ID
} }
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) { visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform timeline visibility lookup. // Visibility not yet cached, perform timeline visibility lookup.
visible, err := f.isStatusPublicTimelineable(ctx, requester, status) visible, err := f.isStatusPublicTimelineable(ctx, requester, status)
if err != nil { if err != nil {

View File

@ -53,7 +53,7 @@ func (f *Filter) StatusVisible(ctx context.Context, requester *gtsmodel.Account,
requesterID = requester.ID requesterID = requester.ID
} }
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) { visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform visibility lookup. // Visibility not yet cached, perform visibility lookup.
visible, err := f.isStatusVisible(ctx, requester, status) visible, err := f.isStatusVisible(ctx, requester, status)
if err != nil { if err != nil {

View File

@ -2833,20 +2833,20 @@ func NewTestFediPeople() map[string]vocab.ActivityStreamsPerson {
"image/png", "image/png",
false, false,
), ),
"https://example.org/users/Some_User": newAPPerson( "http://example.org/users/Some_User": newAPPerson(
URLMustParse("https://example.org/users/Some_User"), URLMustParse("http://example.org/users/Some_User"),
URLMustParse("https://example.org/users/Some_User/following"), URLMustParse("http://example.org/users/Some_User/following"),
URLMustParse("https://example.org/users/Some_User/followers"), URLMustParse("http://example.org/users/Some_User/followers"),
URLMustParse("https://example.org/users/Some_User/inbox"), URLMustParse("http://example.org/users/Some_User/inbox"),
URLMustParse("https://example.org/sharedInbox"), URLMustParse("http://example.org/sharedInbox"),
URLMustParse("https://example.org/users/Some_User/outbox"), URLMustParse("http://example.org/users/Some_User/outbox"),
URLMustParse("https://example.org/users/Some_User/collections/featured"), URLMustParse("http://example.org/users/Some_User/collections/featured"),
"Some_User", "Some_User",
"just some user, don't mind me", "just some user, don't mind me",
"Peepee poo poo", "Peepee poo poo",
URLMustParse("https://example.org/@Some_User"), URLMustParse("http://example.org/@Some_User"),
true, true,
URLMustParse("https://example.org/users/Some_User#main-key"), URLMustParse("http://example.org/users/Some_User#main-key"),
someUserPub, someUserPub,
nil, nil,
"image/jpeg", "image/jpeg",

View File

@ -1,433 +0,0 @@
package result
import (
"context"
"reflect"
_ "unsafe"
"codeberg.org/gruf/go-cache/v3/simple"
"codeberg.org/gruf/go-errors/v2"
)
// Lookup represents a struct object lookup method in the cache.
type Lookup struct {
// Name is a period ('.') separated string
// of struct fields this Key encompasses.
Name string
// AllowZero indicates whether to accept and cache
// under zero value keys, otherwise ignore them.
AllowZero bool
// Multi allows specifying a key capable of storing
// multiple results. Note this only supports invalidate.
Multi bool
}
// Cache provides a means of caching value structures, along with
// the results of attempting to load them. An example usecase of this
// cache would be in wrapping a database, allowing caching of sql.ErrNoRows.
type Cache[T any] struct {
cache simple.Cache[int64, *result] // underlying result cache
lookups structKeys // pre-determined struct lookups
invalid func(T) // store unwrapped invalidate callback.
ignore func(error) bool // determines cacheable errors
copy func(T) T // copies a Value type
next int64 // update key counter
}
// New returns a new initialized Cache, with given lookups, underlying value copy function and provided capacity.
func New[T any](lookups []Lookup, copy func(T) T, cap int) *Cache[T] {
var z T
// Determine generic type
t := reflect.TypeOf(z)
// Iteratively deref pointer type
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
// Ensure that this is a struct type
if t.Kind() != reflect.Struct {
panic("generic parameter type must be struct (or ptr to)")
}
// Allocate new cache object
c := new(Cache[T])
c.copy = copy // use copy fn.
c.lookups = make([]structKey, len(lookups))
for i, lookup := range lookups {
// Create keyed field info for lookup
c.lookups[i] = newStructKey(lookup, t)
}
// Create and initialize underlying cache
c.cache.Init(0, cap)
c.SetEvictionCallback(nil)
c.SetInvalidateCallback(nil)
c.IgnoreErrors(nil)
return c
}
// SetEvictionCallback sets the eviction callback to the provided hook.
func (c *Cache[T]) SetEvictionCallback(hook func(T)) {
if hook == nil {
// Ensure non-nil hook.
hook = func(T) {}
}
c.cache.SetEvictionCallback(func(pkey int64, res *result) {
c.cache.Lock()
for _, key := range res.Keys {
// Delete key->pkey lookup
pkeys := key.info.pkeys
delete(pkeys, key.key)
}
c.cache.Unlock()
if res.Error != nil {
// Skip value hooks
putResult(res)
return
}
// Free result and call hook.
v := res.Value.(T)
putResult(res)
hook(v)
})
}
// SetInvalidateCallback sets the invalidate callback to the provided hook.
func (c *Cache[T]) SetInvalidateCallback(hook func(T)) {
if hook == nil {
// Ensure non-nil hook.
hook = func(T) {}
} // store hook.
c.invalid = hook
c.cache.SetInvalidateCallback(func(pkey int64, res *result) {
c.cache.Lock()
for _, key := range res.Keys {
// Delete key->pkey lookup
pkeys := key.info.pkeys
delete(pkeys, key.key)
}
c.cache.Unlock()
if res.Error != nil {
// Skip value hooks
putResult(res)
return
}
// Free result and call hook.
v := res.Value.(T)
putResult(res)
hook(v)
})
}
// IgnoreErrors allows setting a function hook to determine which error types should / not be cached.
func (c *Cache[T]) IgnoreErrors(ignore func(error) bool) {
if ignore == nil {
ignore = func(err error) bool {
return errors.Is(err, context.Canceled) ||
errors.Is(err, context.DeadlineExceeded)
}
}
c.cache.Lock()
c.ignore = ignore
c.cache.Unlock()
}
// Load will attempt to load an existing result from the cacche for the given lookup and key parts, else calling the provided load function and caching the result.
func (c *Cache[T]) Load(lookup string, load func() (T, error), keyParts ...any) (T, error) {
info := c.lookups.get(lookup)
key := info.genKey(keyParts)
return c.load(info, key, load)
}
// Has checks the cache for a positive result under the given lookup and key parts.
func (c *Cache[T]) Has(lookup string, keyParts ...any) bool {
info := c.lookups.get(lookup)
key := info.genKey(keyParts)
return c.has(info, key)
}
// Store will call the given store function, and on success store the value in the cache as a positive result.
func (c *Cache[T]) Store(value T, store func() error) error {
// Attempt to store this value.
if err := store(); err != nil {
return err
}
// Prepare cached result.
result := getResult()
result.Keys = c.lookups.generate(value)
result.Value = c.copy(value)
result.Error = nil
var evict func()
// Lock cache.
c.cache.Lock()
defer func() {
// Unlock cache.
c.cache.Unlock()
if evict != nil {
// Call evict.
evict()
}
// Call invalidate.
c.invalid(value)
}()
// Store result in cache.
evict = c.store(result)
return nil
}
// Invalidate will invalidate any result from the cache found under given lookup and key parts.
func (c *Cache[T]) Invalidate(lookup string, keyParts ...any) {
info := c.lookups.get(lookup)
key := info.genKey(keyParts)
c.invalidate(info, key)
}
// Clear empties the cache, calling the invalidate callback where necessary.
func (c *Cache[T]) Clear() { c.Trim(100) }
// Trim ensures the cache stays within percentage of total capacity, truncating where necessary.
func (c *Cache[T]) Trim(perc float64) { c.cache.Trim(perc) }
func (c *Cache[T]) load(lookup *structKey, key string, load func() (T, error)) (T, error) {
if !lookup.unique { // ensure this lookup only returns 1 result
panic("non-unique lookup does not support load: " + lookup.name)
}
var (
zero T
res *result
)
// Acquire cache lock
c.cache.Lock()
// Look for primary key for cache key (only accept len=1)
if pkeys := lookup.pkeys[key]; len(pkeys) == 1 {
// Fetch the result for primary key
entry, ok := c.cache.Cache.Get(pkeys[0])
if ok {
// Since the invalidation / eviction hooks acquire a mutex
// lock separately, and only at this point are the pkeys
// updated, there is a chance that a primary key may return
// no matching entry. Hence we have to check for it here.
res = entry.Value.(*result)
}
}
// Done with lock
c.cache.Unlock()
if res == nil {
// Generate fresh result.
value, err := load()
if err != nil {
if c.ignore(err) {
// don't cache this error type
return zero, err
}
// Alloc result.
res = getResult()
// Store error result.
res.Error = err
// This load returned an error, only
// store this item under provided key.
res.Keys = []cacheKey{{
info: lookup,
key: key,
}}
} else {
// Alloc result.
res = getResult()
// Store value result.
res.Value = value
// This was a successful load, generate keys.
res.Keys = c.lookups.generate(res.Value)
}
var evict func()
// Lock cache.
c.cache.Lock()
defer func() {
// Unlock cache.
c.cache.Unlock()
if evict != nil {
// Call evict.
evict()
}
}()
// Store result in cache.
evict = c.store(res)
}
// Catch and return cached error
if err := res.Error; err != nil {
return zero, err
}
// Copy value from cached result.
v := c.copy(res.Value.(T))
return v, nil
}
func (c *Cache[T]) has(lookup *structKey, key string) bool {
var res *result
// Acquire cache lock
c.cache.Lock()
// Look for primary key for cache key (only accept len=1)
if pkeys := lookup.pkeys[key]; len(pkeys) == 1 {
// Fetch the result for primary key
entry, ok := c.cache.Cache.Get(pkeys[0])
if ok {
// Since the invalidation / eviction hooks acquire a mutex
// lock separately, and only at this point are the pkeys
// updated, there is a chance that a primary key may return
// no matching entry. Hence we have to check for it here.
res = entry.Value.(*result)
}
}
// Check for result AND non-error result.
ok := (res != nil && res.Error == nil)
// Done with lock
c.cache.Unlock()
return ok
}
func (c *Cache[T]) store(res *result) (evict func()) {
var toEvict []*result
// Get primary key
res.PKey = c.next
c.next++
if res.PKey > c.next {
panic("cache primary key overflow")
}
for _, key := range res.Keys {
// Look for cache primary keys.
pkeys := key.info.pkeys[key.key]
if key.info.unique && len(pkeys) > 0 {
for _, conflict := range pkeys {
// Get the overlapping result with this key.
entry, ok := c.cache.Cache.Get(conflict)
if !ok {
// Since the invalidation / eviction hooks acquire a mutex
// lock separately, and only at this point are the pkeys
// updated, there is a chance that a primary key may return
// no matching entry. Hence we have to check for it here.
continue
}
// From conflicting entry, drop this key, this
// will prevent eviction cleanup key confusion.
confRes := entry.Value.(*result)
confRes.Keys.drop(key.info.name)
if len(res.Keys) == 0 {
// We just over-wrote the only lookup key for
// this value, so we drop its primary key too.
_ = c.cache.Cache.Delete(conflict)
// Add finished result to evict queue.
toEvict = append(toEvict, confRes)
}
}
// Drop existing.
pkeys = pkeys[:0]
}
// Store primary key lookup.
pkeys = append(pkeys, res.PKey)
key.info.pkeys[key.key] = pkeys
}
// Acquire new cache entry.
entry := simple.GetEntry()
entry.Key = res.PKey
entry.Value = res
evictFn := func(_ int64, entry *simple.Entry) {
// on evict during set, store evicted result.
toEvict = append(toEvict, entry.Value.(*result))
}
// Store main entry under primary key, catch evicted.
c.cache.Cache.SetWithHook(res.PKey, entry, evictFn)
if len(toEvict) == 0 {
// none evicted.
return nil
}
return func() {
for i := range toEvict {
// Rescope result.
res := toEvict[i]
// Call evict hook on each entry.
c.cache.Evict(res.PKey, res)
}
}
}
func (c *Cache[T]) invalidate(lookup *structKey, key string) {
// Look for primary key for cache key
c.cache.Lock()
pkeys := lookup.pkeys[key]
delete(lookup.pkeys, key)
c.cache.Unlock()
// Invalidate all primary keys.
c.cache.InvalidateAll(pkeys...)
}
type result struct {
// Result primary key
PKey int64
// keys accessible under
Keys cacheKeys
// cached value
Value any
// cached error
Error error
}

View File

@ -1,282 +0,0 @@
package result
import (
"fmt"
"reflect"
"strings"
"sync"
"unicode"
"unicode/utf8"
"codeberg.org/gruf/go-byteutil"
"codeberg.org/gruf/go-mangler"
)
// structKeys provides convience methods for a list
// of structKey field combinations used for cache keys.
type structKeys []structKey
// get fetches the structKey info for given lookup name (else, panics).
func (sk structKeys) get(name string) *structKey {
for i := range sk {
if sk[i].name == name {
return &sk[i]
}
}
panic("unknown lookup: \"" + name + "\"")
}
// generate will calculate and produce a slice of cache keys the given value
// can be stored under in the, as determined by receiving struct keys.
func (sk structKeys) generate(a any) []cacheKey {
var keys []cacheKey
// Get reflected value in order
// to access the struct fields
v := reflect.ValueOf(a)
// Iteratively deref pointer value
for v.Kind() == reflect.Pointer {
if v.IsNil() {
panic("nil ptr")
}
v = v.Elem()
}
// Acquire buffer
buf := getBuf()
outer:
for i := range sk {
// Reset buffer
buf.Reset()
// Append each field value to buffer.
for _, field := range sk[i].fields {
fv := v.Field(field.index)
fi := fv.Interface()
// Mangle this key part into buffer.
ok := field.manglePart(buf, fi)
if !ok {
// don't generate keys
// for zero value parts.
continue outer
}
// Append part separator.
buf.B = append(buf.B, '.')
}
// Drop last '.'
buf.Truncate(1)
// Append new cached key to slice
keys = append(keys, cacheKey{
info: &sk[i],
key: string(buf.B), // copy
})
}
// Release buf
putBuf(buf)
return keys
}
type cacheKeys []cacheKey
// drop will drop the cachedKey with lookup name from receiving cacheKeys slice.
func (ck *cacheKeys) drop(name string) {
_ = *ck // move out of loop
for i := range *ck {
if (*ck)[i].info.name == name {
(*ck) = append((*ck)[:i], (*ck)[i+1:]...)
break
}
}
}
// cacheKey represents an actual cached key.
type cacheKey struct {
// info is a reference to the structKey this
// cacheKey is representing. This is a shared
// reference and as such only the structKey.pkeys
// lookup map is expecting to be modified.
info *structKey
// value is the actual string representing
// this cache key for hashmap lookups.
key string
}
// structKey represents a list of struct fields
// encompassing a single cache key, the string name
// of the lookup, the lookup map to primary cache
// keys, and the key's possible zero value string.
type structKey struct {
// name is the provided cache lookup name for
// this particular struct key, consisting of
// period ('.') separated struct field names.
name string
// unique determines whether this structKey supports
// multiple or just the singular unique result.
unique bool
// fields is a slice of runtime struct field
// indices, of fields encompassed by this key.
fields []structField
// pkeys is a lookup of stored struct key values
// to the primary cache lookup key (int64). this
// is protected by the main cache mutex.
pkeys map[string][]int64
}
// newStructKey will generate a structKey{} information object for user-given lookup
// key information, and the receiving generic paramter's type information. Panics on error.
func newStructKey(lk Lookup, t reflect.Type) structKey {
var sk structKey
// Set the lookup name
sk.name = lk.Name
// Split dot-separated lookup to get
// the individual struct field names
names := strings.Split(lk.Name, ".")
// Allocate the mangler and field indices slice.
sk.fields = make([]structField, len(names))
for i, name := range names {
// Get field info for given name
ft, ok := t.FieldByName(name)
if !ok {
panic("no field found for name: \"" + name + "\"")
}
// Check field is usable
if !isExported(name) {
panic("field must be exported")
}
// Set the runtime field index
sk.fields[i].index = ft.Index[0]
// Allocate new instance of field
v := reflect.New(ft.Type)
v = v.Elem()
// Fetch mangler for field type.
sk.fields[i].mangle = mangler.Get(ft.Type)
if !lk.AllowZero {
// Append the mangled zero value interface
zero := sk.fields[i].mangle(nil, v.Interface())
sk.fields[i].zero = string(zero)
}
}
// Set unique lookup flag.
sk.unique = !lk.Multi
// Allocate primary lookup map
sk.pkeys = make(map[string][]int64)
return sk
}
// genKey generates a cache key string for given key parts (i.e. serializes them using "go-mangler").
func (sk *structKey) genKey(parts []any) string {
// Check this expected no. key parts.
if len(parts) != len(sk.fields) {
panic(fmt.Sprintf("incorrect no. key parts provided: want=%d received=%d", len(parts), len(sk.fields)))
}
// Acquire buffer
buf := getBuf()
buf.Reset()
for i, part := range parts {
// Mangle this key part into buffer.
// specifically ignoring whether this
// is returning a zero value key part.
_ = sk.fields[i].manglePart(buf, part)
// Append part separator.
buf.B = append(buf.B, '.')
}
// Drop last '.'
buf.Truncate(1)
// Create str copy
str := string(buf.B)
// Release buf
putBuf(buf)
return str
}
type structField struct {
// index is the reflect index of this struct field.
index int
// zero is the possible zero value for this
// key part. if set, this will _always_ be
// non-empty due to how the mangler works.
//
// i.e. zero = "" --> allow zero value keys
// zero != "" --> don't allow zero value keys
zero string
// mangle is the mangler function for
// serializing values of this struct field.
mangle mangler.Mangler
}
// manglePart ...
func (field *structField) manglePart(buf *byteutil.Buffer, part any) bool {
// Start of part bytes.
start := len(buf.B)
// Mangle this key part into buffer.
buf.B = field.mangle(buf.B, part)
// End of part bytes.
end := len(buf.B)
// Return whether this is zero value.
return (field.zero == "" ||
string(buf.B[start:end]) != field.zero)
}
// isExported checks whether function name is exported.
func isExported(fnName string) bool {
r, _ := utf8.DecodeRuneInString(fnName)
return unicode.IsUpper(r)
}
// bufpool provides a memory pool of byte
// buffers use when encoding key types.
var bufPool = sync.Pool{
New: func() any {
return &byteutil.Buffer{B: make([]byte, 0, 512)}
},
}
// getBuf acquires a byte buffer from memory pool.
func getBuf() *byteutil.Buffer {
return bufPool.Get().(*byteutil.Buffer)
}
// putBuf replaces a byte buffer back in memory pool.
func putBuf(buf *byteutil.Buffer) {
if buf.Cap() > int(^uint16(0)) {
return // drop large bufs
}
bufPool.Put(buf)
}

View File

@ -1,25 +0,0 @@
package result
import "sync"
// resultPool is a global pool for result
// objects, regardless of cache type.
var resultPool sync.Pool
// getEntry fetches a result from pool, or allocates new.
func getResult() *result {
v := resultPool.Get()
if v == nil {
return new(result)
}
return v.(*result)
}
// putResult replaces a result in the pool.
func putResult(r *result) {
r.PKey = 0
r.Keys = nil
r.Value = nil
r.Error = nil
resultPool.Put(r)
}

9
vendor/codeberg.org/gruf/go-structr/LICENSE generated vendored Normal file
View File

@ -0,0 +1,9 @@
MIT License
Copyright (c) gruf
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

5
vendor/codeberg.org/gruf/go-structr/README.md generated vendored Normal file
View File

@ -0,0 +1,5 @@
# go-structr
A performant struct caching library with automated indexing by arbitrary combinations of fields, including support for negative results (errors!). An example use case is in database lookups.
This is a core underpinning of [GoToSocial](https://github.com/superseriousbusiness/gotosocial)'s performance.

731
vendor/codeberg.org/gruf/go-structr/cache.go generated vendored Normal file
View File

@ -0,0 +1,731 @@
package structr
import (
"context"
"errors"
"reflect"
"sync"
)
// DefaultIgnoreErr is the default function used to
// ignore (i.e. not cache) incoming error results during
// Load() calls. By default ignores context pkg errors.
func DefaultIgnoreErr(err error) bool {
return errors.Is(err, context.Canceled) ||
errors.Is(err, context.DeadlineExceeded)
}
// Config defines config variables
// for initializing a struct cache.
type Config[StructType any] struct {
// Indices defines indices to create
// in the Cache for the receiving
// generic struct type parameter.
Indices []IndexConfig
// MaxSize defines the maximum number
// of results allowed in the Cache at
// one time, before old results start
// getting evicted.
MaxSize int
// IgnoreErr defines which errors to
// ignore (i.e. not cache) returned
// from load function callback calls.
// This may be left as nil, on which
// DefaultIgnoreErr will be used.
IgnoreErr func(error) bool
// CopyValue provides a means of copying
// cached values, to ensure returned values
// do not share memory with those in cache.
CopyValue func(StructType) StructType
// Invalidate is called when cache values
// (NOT errors) are invalidated, either
// as the values passed to Put() / Store(),
// or by the keys by calls to Invalidate().
Invalidate func(StructType)
}
// Cache provides a structure cache with automated
// indexing and lookups by any initialization-defined
// combination of fields (as long as serialization is
// supported by codeberg.org/gruf/go-mangler). This
// also supports caching of negative results by errors
// returned from the LoadOne() series of functions.
type Cache[StructType any] struct {
// indices used in storing passed struct
// types by user defined sets of fields.
indices []Index[StructType]
// keeps track of all indexed results,
// in order of last recently used (LRU).
lruList list[*result[StructType]]
// memory pools of common types.
llsPool []*list[*result[StructType]]
resPool []*result[StructType]
keyPool []*indexkey[StructType]
// max cache size, imposes size
// limit on the lruList in order
// to evict old entries.
maxSize int
// hook functions.
ignore func(error) bool
copy func(StructType) StructType
invalid func(StructType)
// protective mutex, guards:
// - Cache{}.lruList
// - Index{}.data
// - Cache{} hook fns
// - Cache{} pools
mutex sync.Mutex
}
// Init initializes the cache with given configuration
// including struct fields to index, and necessary fns.
func (c *Cache[T]) Init(config Config[T]) {
if len(config.Indices) == 0 {
panic("no indices provided")
}
if config.IgnoreErr == nil {
config.IgnoreErr = DefaultIgnoreErr
}
if config.CopyValue == nil {
panic("copy value function must be provided")
}
if config.MaxSize < 2 {
panic("minimum cache size is 2 for LRU to work")
}
// Safely copy over
// provided config.
c.mutex.Lock()
c.indices = make([]Index[T], len(config.Indices))
for i, config := range config.Indices {
c.indices[i].init(config)
}
c.ignore = config.IgnoreErr
c.copy = config.CopyValue
c.invalid = config.Invalidate
c.maxSize = config.MaxSize
c.mutex.Unlock()
}
// Index selects index with given name from cache, else panics.
func (c *Cache[T]) Index(name string) *Index[T] {
for i := range c.indices {
if c.indices[i].name == name {
return &c.indices[i]
}
}
panic("unknown index: " + name)
}
// GetOne fetches one value from the cache stored under index, using key generated from key parts.
// Note that given number of key parts MUST match expected number and types of the given index name.
func (c *Cache[T]) GetOne(index string, keyParts ...any) (T, bool) {
// Get index with name.
idx := c.Index(index)
// Generate index key from provided parts.
key, ok := idx.keygen.FromParts(keyParts...)
if !ok {
var zero T
return zero, false
}
// Fetch one value for key.
return c.GetOneBy(idx, key)
}
// GetOneBy fetches value from cache stored under index, using precalculated index key.
func (c *Cache[T]) GetOneBy(index *Index[T], key string) (T, bool) {
if index == nil {
panic("no index given")
} else if !index.unique {
panic("cannot get one by non-unique index")
}
values := c.GetBy(index, key)
if len(values) == 0 {
var zero T
return zero, false
}
return values[0], true
}
// Get fetches values from the cache stored under index, using keys generated from given key parts.
// Note that each number of key parts MUST match expected number and types of the given index name.
func (c *Cache[T]) Get(index string, keysParts ...[]any) []T {
// Get index with name.
idx := c.Index(index)
// Preallocate expected keys slice length.
keys := make([]string, 0, len(keysParts))
// Acquire buf.
buf := getBuf()
for _, parts := range keysParts {
// Reset buf.
buf.Reset()
// Generate key from provided parts into buffer.
if !idx.keygen.AppendFromParts(buf, parts...) {
continue
}
// Get string copy of
// genarated idx key.
key := string(buf.B)
// Append key to keys.
keys = append(keys, key)
}
// Done with buf.
putBuf(buf)
// Continue fetching values.
return c.GetBy(idx, keys...)
}
// GetBy fetches values from the cache stored under index, using precalculated index keys.
func (c *Cache[T]) GetBy(index *Index[T], keys ...string) []T {
if index == nil {
panic("no index given")
}
// Preallocate a slice of est. len.
values := make([]T, 0, len(keys))
// Acquire lock.
c.mutex.Lock()
// Check cache init.
if c.copy == nil {
c.mutex.Unlock()
panic("not initialized")
}
// Check index for all keys.
for _, key := range keys {
// Get indexed results.
list := index.data[key]
if list != nil {
// Concatenate all results with values.
list.rangefn(func(e *elem[*result[T]]) {
if e.Value.err != nil {
return
}
// Append a copy of value.
value := c.copy(e.Value.value)
values = append(values, value)
// Push to front of LRU list, USING
// THE RESULT'S LRU ENTRY, NOT THE
// INDEX KEY ENTRY. VERY IMPORTANT!!
c.lruList.moveFront(&e.Value.entry)
})
}
}
// Done with lock.
c.mutex.Unlock()
return values
}
// Put will insert the given values into cache,
// calling any invalidate hook on each value.
func (c *Cache[T]) Put(values ...T) {
// Acquire lock.
c.mutex.Lock()
// Get func ptrs.
invalid := c.invalid
// Check cache init.
if c.copy == nil {
c.mutex.Unlock()
panic("not initialized")
}
// Store all the passed values.
for _, value := range values {
c.store(nil, "", value, nil)
}
// Done with lock.
c.mutex.Unlock()
if invalid != nil {
// Pass all invalidated values
// to given user hook (if set).
for _, value := range values {
invalid(value)
}
}
}
// LoadOne fetches one result from the cache stored under index, using key generated from key parts.
// In the case that no result is found, the provided load callback will be used to hydrate the cache.
// Note that given number of key parts MUST match expected number and types of the given index name.
func (c *Cache[T]) LoadOne(index string, load func() (T, error), keyParts ...any) (T, error) {
// Get index with name.
idx := c.Index(index)
// Generate cache from from provided parts.
key, _ := idx.keygen.FromParts(keyParts...)
// Continue loading this result.
return c.LoadOneBy(idx, load, key)
}
// LoadOneBy fetches one result from the cache stored under index, using precalculated index key.
// In the case that no result is found, provided load callback will be used to hydrate the cache.
func (c *Cache[T]) LoadOneBy(index *Index[T], load func() (T, error), key string) (T, error) {
if index == nil {
panic("no index given")
} else if !index.unique {
panic("cannot get one by non-unique index")
}
var (
// whether a result was found
// (and so val / err are set).
ok bool
// separate value / error ptrs
// as the result is liable to
// change outside of lock.
val T
err error
)
// Acquire lock.
c.mutex.Lock()
// Get func ptrs.
ignore := c.ignore
// Check init'd.
if c.copy == nil ||
ignore == nil {
c.mutex.Unlock()
panic("not initialized")
}
// Get indexed results.
list := index.data[key]
if ok = (list != nil && list.head != nil); ok {
e := list.head
// Extract val / err.
val = e.Value.value
err = e.Value.err
if err == nil {
// We only ever ret
// a COPY of value.
val = c.copy(val)
}
// Push to front of LRU list, USING
// THE RESULT'S LRU ENTRY, NOT THE
// INDEX KEY ENTRY. VERY IMPORTANT!!
c.lruList.moveFront(&e.Value.entry)
}
// Done with lock.
c.mutex.Unlock()
if ok {
// result found!
return val, err
}
// Load new result.
val, err = load()
// Check for ignored
// (transient) errors.
if ignore(err) {
return val, err
}
// Acquire lock.
c.mutex.Lock()
// Index this new loaded result.
// Note this handles copying of
// the provided value, so it is
// safe for us to return as-is.
c.store(index, key, val, err)
// Done with lock.
c.mutex.Unlock()
return val, err
}
// Load fetches values from the cache stored under index, using keys generated from given key parts. The provided get callback is used
// to load groups of values from the cache by the key generated from the key parts provided to the inner callback func, where the returned
// boolean indicates whether any values are currently stored. After the get callback has returned, the cache will then call provided load
// callback to hydrate the cache with any other values. Example usage here is that you may see which values are cached using 'get', and load
// the remaining uncached values using 'load', to minimize database queries. Cached error results are not included or returned by this func.
// Note that given number of key parts MUST match expected number and types of the given index name, in those provided to the get callback.
func (c *Cache[T]) Load(index string, get func(load func(keyParts ...any) bool), load func() ([]T, error)) (values []T, err error) {
return c.LoadBy(c.Index(index), get, load)
}
// LoadBy fetches values from the cache stored under index, using precalculated index key. The provided get callback is used to load
// groups of values from the cache by the key generated from the key parts provided to the inner callback func, where the returned boolea
// indicates whether any values are currently stored. After the get callback has returned, the cache will then call provided load callback
// to hydrate the cache with any other values. Example usage here is that you may see which values are cached using 'get', and load the
// remaining uncached values using 'load', to minimize database queries. Cached error results are not included or returned by this func.
// Note that given number of key parts MUST match expected number and types of the given index name, in those provided to the get callback.
func (c *Cache[T]) LoadBy(index *Index[T], get func(load func(keyParts ...any) bool), load func() ([]T, error)) (values []T, err error) {
if index == nil {
panic("no index given")
}
// Acquire lock.
c.mutex.Lock()
// Check init'd.
if c.copy == nil {
c.mutex.Unlock()
panic("not initialized")
}
var unlocked bool
defer func() {
// Deferred unlock to catch
// any user function panics.
if !unlocked {
c.mutex.Unlock()
}
}()
// Acquire buf.
buf := getBuf()
// Pass cache check to user func.
get(func(keyParts ...any) bool {
// Reset buf.
buf.Reset()
// Generate index key from provided key parts.
if !index.keygen.AppendFromParts(buf, keyParts...) {
return false
}
// Get temp generated key str,
// (not needed after return).
keyStr := buf.String()
// Get all indexed results.
list := index.data[keyStr]
if list != nil && list.len > 0 {
// Value length before
// any below appends.
before := len(values)
// Concatenate all results with values.
list.rangefn(func(e *elem[*result[T]]) {
if e.Value.err != nil {
return
}
// Append a copy of value.
value := c.copy(e.Value.value)
values = append(values, value)
// Push to front of LRU list, USING
// THE RESULT'S LRU ENTRY, NOT THE
// INDEX KEY ENTRY. VERY IMPORTANT!!
c.lruList.moveFront(&e.Value.entry)
})
// Only if values changed did
// we actually find anything.
return len(values) != before
}
return false
})
// Done with buf.
putBuf(buf)
// Done with lock.
c.mutex.Unlock()
unlocked = true
// Load uncached values.
uncached, err := load()
if err != nil {
return nil, err
}
// Insert uncached.
c.Put(uncached...)
// Append uncached to return values.
values = append(values, uncached...)
return
}
// Store will call the given store callback, on non-error then
// passing the provided value to the Put() function. On error
// return the value is still passed to stored invalidate hook.
func (c *Cache[T]) Store(value T, store func() error) error {
// Store value.
err := store()
if err != nil {
// Get func ptrs.
c.mutex.Lock()
invalid := c.invalid
c.mutex.Unlock()
// On error don't store
// value, but still pass
// to invalidate hook.
if invalid != nil {
invalid(value)
}
return err
}
// Store value.
c.Put(value)
return nil
}
// Invalidate generates index key from parts and invalidates all stored under it.
func (c *Cache[T]) Invalidate(index string, keyParts ...any) {
// Get index with name.
idx := c.Index(index)
// Generate cache from from provided parts.
key, ok := idx.keygen.FromParts(keyParts...)
if !ok {
return
}
// Continue invalidation.
c.InvalidateBy(idx, key)
}
// InvalidateBy invalidates all results stored under index key.
func (c *Cache[T]) InvalidateBy(index *Index[T], key string) {
if index == nil {
panic("no index given")
}
var values []T
// Acquire lock.
c.mutex.Lock()
// Get func ptrs.
invalid := c.invalid
// Delete all results under key from index, collecting
// value results and dropping them from all their indices.
index_delete(c, index, key, func(del *result[T]) {
if del.err == nil {
values = append(values, del.value)
}
c.delete(del)
})
// Done with lock.
c.mutex.Unlock()
if invalid != nil {
// Pass all invalidated values
// to given user hook (if set).
for _, value := range values {
invalid(value)
}
}
}
// Trim will truncate the cache to ensure it
// stays within given percentage of MaxSize.
func (c *Cache[T]) Trim(perc float64) {
// Acquire lock.
c.mutex.Lock()
// Calculate number of cache items to drop.
max := (perc / 100) * float64(c.maxSize)
diff := c.lruList.len - int(max)
if diff <= 0 {
// Trim not needed.
c.mutex.Unlock()
return
}
// Iterate over 'diff' results
// from back (oldest) of cache.
for i := 0; i < diff; i++ {
// Get oldest LRU element.
oldest := c.lruList.tail
if oldest == nil {
// reached end.
break
}
// Drop oldest from cache.
c.delete(oldest.Value)
}
// Done with lock.
c.mutex.Unlock()
}
// Clear empties the cache by calling .Trim(0).
func (c *Cache[T]) Clear() { c.Trim(0) }
// Clean drops unused items from its memory pools.
// Useful to free memory if cache has downsized.
func (c *Cache[T]) Clean() {
c.mutex.Lock()
c.llsPool = nil
c.resPool = nil
c.keyPool = nil
c.mutex.Unlock()
}
// Len returns the current length of cache.
func (c *Cache[T]) Len() int {
c.mutex.Lock()
l := c.lruList.len
c.mutex.Unlock()
return l
}
// Cap returns the maximum capacity (size) of cache.
func (c *Cache[T]) Cap() int {
c.mutex.Lock()
m := c.maxSize
c.mutex.Unlock()
return m
}
// store will store the given value / error result in the cache, storing it under the
// already provided index + key if provided, else generating keys from provided value.
func (c *Cache[T]) store(index *Index[T], key string, value T, err error) {
// Acquire new result.
res := result_acquire(c)
if index != nil {
// Append result to the provided
// precalculated key and its index.
index_append(c, index, key, res)
} else if err != nil {
// This is an error result without
// an index provided, nothing we
// can do here so release result.
result_release(c, res)
return
}
// Set and check the result error.
if res.err = err; res.err == nil {
// This is value result, we need to
// store it under all other indices
// other than the provided.
//
// Create COPY of value.
res.value = c.copy(value)
// Get reflected value of incoming
// value, used during cache key gen.
rvalue := reflect.ValueOf(value)
// Acquire buf.
buf := getBuf()
for i := range c.indices {
// Get current index ptr.
idx := &(c.indices[i])
if idx == index {
// Already stored under
// this index, ignore.
continue
}
// Generate key from reflect value,
// (this ignores zero value keys).
buf.Reset() // reset buf first
if !idx.keygen.appendFromRValue(buf, rvalue) {
continue
}
// Alloc key copy.
key := string(buf.B)
// Append result to index at key.
index_append(c, idx, key, res)
}
// Done with buf.
putBuf(buf)
}
if c.lruList.len > c.maxSize {
// Cache has hit max size!
// Drop the oldest element.
res := c.lruList.tail.Value
c.delete(res)
}
}
// delete will delete the given result from the cache, deleting
// it from all indices it is stored under, and main LRU list.
func (c *Cache[T]) delete(res *result[T]) {
for len(res.keys) != 0 {
// Pop indexkey at end of list.
ikey := res.keys[len(res.keys)-1]
res.keys = res.keys[:len(res.keys)-1]
// Drop this result from list at key.
index_deleteOne(c, ikey.index, ikey)
// Release ikey to pool.
indexkey_release(c, ikey)
}
// Release res to pool.
result_release(c, res)
}

41
vendor/codeberg.org/gruf/go-structr/debug.go generated vendored Normal file
View File

@ -0,0 +1,41 @@
package structr
// String returns a useful debugging repr of result.
// func (r *result[T]) String() string {
// keysbuf := getBuf()
// keysbuf.B = append(keysbuf.B, '[')
// for i := range r.keys {
// keysbuf.B = strconv.AppendQuote(keysbuf.B, r.keys[i].key)
// keysbuf.B = append(keysbuf.B, ',')
// }
// if len(keysbuf.B) > 0 {
// keysbuf.B = keysbuf.B[:len(keysbuf.B)-1]
// }
// keysbuf.B = append(keysbuf.B, ']')
// str := fmt.Sprintf("{value=%v err=%v keys=%s}", r.value, r.err, keysbuf.B)
// putBuf(keysbuf)
// return str
// }
// String returns a useful debugging repr of index.
// func (i *Index[T]) String() string {
// databuf := getBuf()
// for key, values := range i.data {
// databuf.WriteString("key")
// databuf.B = strconv.AppendQuote(databuf.B, key)
// databuf.B = append(databuf.B, '=')
// fmt.Fprintf(databuf, "%v", values)
// databuf.B = append(databuf.B, ' ')
// }
// if len(i.data) > 0 {
// databuf.B = databuf.B[:len(databuf.B)-1]
// }
// str := fmt.Sprintf("{name=%s data={%s}}", i.name, databuf.B)
// putBuf(databuf)
// return str
// }
// String returns a useful debugging repr of indexkey.
// func (i *indexkey[T]) String() string {
// return i.index.name + "[" + strconv.Quote(i.key) + "]"
// }

213
vendor/codeberg.org/gruf/go-structr/index.go generated vendored Normal file
View File

@ -0,0 +1,213 @@
package structr
import (
"strings"
)
// IndexConfig defines config variables
// for initializing a struct index.
type IndexConfig struct {
// Fields should contain a comma-separated
// list of struct fields used when generating
// keys for this index. Nested fields should
// be specified using periods. An example:
// "Username,Favorites.Color"
Fields string
// Multiple indicates whether to accept multiple
// possible values for any single index key. The
// default behaviour is to only accept one value
// and overwrite existing on any write operation.
Multiple bool
// AllowZero indicates whether to accept zero
// value fields in index keys. i.e. whether to
// index structs for this set of field values
// IF any one of those field values is the zero
// value for that type. The default behaviour
// is to skip indexing structs for this lookup
// when any of the indexing fields are zero.
AllowZero bool
}
// Index is an exposed Cache internal model, used to
// generate keys and store struct results by the init
// defined key generation configuration. This model is
// exposed to provide faster lookups in the case that
// you would like to manually provide the used index
// via the Cache.___By() series of functions, or access
// the underlying index key generator.
type Index[StructType any] struct {
// name is the actual name of this
// index, which is the unparsed
// string value of contained fields.
name string
// struct field key serializer.
keygen KeyGen[StructType]
// backing in-memory data store of
// generated index keys to result lists.
data map[string]*list[*result[StructType]]
// whether to allow
// multiple results
// per index key.
unique bool
}
// init initializes this index with the given configuration.
func (i *Index[T]) init(config IndexConfig) {
fields := strings.Split(config.Fields, ",")
i.name = config.Fields
i.keygen = NewKeyGen[T](fields, config.AllowZero)
i.unique = !config.Multiple
i.data = make(map[string]*list[*result[T]])
}
// KeyGen returns the key generator associated with this index.
func (i *Index[T]) KeyGen() *KeyGen[T] {
return &i.keygen
}
func index_append[T any](c *Cache[T], i *Index[T], key string, res *result[T]) {
// Acquire + setup indexkey.
ikey := indexkey_acquire(c)
ikey.entry.Value = res
ikey.key = key
ikey.index = i
// Append to result's indexkeys.
res.keys = append(res.keys, ikey)
// Get list at key.
l := i.data[key]
if l == nil {
// Allocate new list.
l = list_acquire(c)
i.data[key] = l
} else if i.unique {
// Remove currently
// indexed result.
old := l.head
l.remove(old)
// Get ptr to old
// result before we
// release to pool.
res := old.Value
// Drop this index's key from
// old res now not indexed here.
result_dropIndex(c, res, i)
if len(res.keys) == 0 {
// Old res now unused,
// release to mem pool.
result_release(c, res)
}
}
// Add result indexkey to
// front of results list.
l.pushFront(&ikey.entry)
}
func index_deleteOne[T any](c *Cache[T], i *Index[T], ikey *indexkey[T]) {
// Get list at key.
l := i.data[ikey.key]
if l == nil {
return
}
// Remove from list.
l.remove(&ikey.entry)
if l.len == 0 {
// Remove list from map.
delete(i.data, ikey.key)
// Release list to pool.
list_release(c, l)
}
}
func index_delete[T any](c *Cache[T], i *Index[T], key string, fn func(*result[T])) {
if fn == nil {
panic("nil fn")
}
// Get list at key.
l := i.data[key]
if l == nil {
return
}
// Delete data at key.
delete(i.data, key)
// Iterate results in list.
for x := 0; x < l.len; x++ {
// Pop current head.
res := l.head.Value
l.remove(l.head)
// Delete index's key
// from result tracking.
result_dropIndex(c, res, i)
// Call hook.
fn(res)
}
// Release list to pool.
list_release(c, l)
}
type indexkey[T any] struct {
// linked list entry the related
// result is stored under in the
// Index.data[key] linked list.
entry elem[*result[T]]
// key is the generated index key
// the related result is indexed
// under, in the below index.
key string
// index is the index that the
// related result is indexed in.
index *Index[T]
}
func indexkey_acquire[T any](c *Cache[T]) *indexkey[T] {
var ikey *indexkey[T]
if len(c.keyPool) == 0 {
// Allocate new key.
ikey = new(indexkey[T])
} else {
// Pop result from pool slice.
ikey = c.keyPool[len(c.keyPool)-1]
c.keyPool = c.keyPool[:len(c.keyPool)-1]
}
return ikey
}
func indexkey_release[T any](c *Cache[T], ikey *indexkey[T]) {
// Reset indexkey.
ikey.entry.Value = nil
ikey.key = ""
ikey.index = nil
// Release indexkey to memory pool.
c.keyPool = append(c.keyPool, ikey)
}

204
vendor/codeberg.org/gruf/go-structr/key.go generated vendored Normal file
View File

@ -0,0 +1,204 @@
package structr
import (
"reflect"
"strings"
"codeberg.org/gruf/go-byteutil"
"codeberg.org/gruf/go-mangler"
)
// KeyGen is the underlying index key generator
// used within Index, and therefore Cache itself.
type KeyGen[StructType any] struct {
// fields contains our representation of
// the struct fields contained in the
// creation of keys by this generator.
fields []structfield
// zero specifies whether zero
// value fields are permitted.
zero bool
}
// NewKeyGen returns a new initialized KeyGen for the receiving generic
// parameter type, comprising of the given field strings, and whether to
// allow zero values to be included within generated output strings.
func NewKeyGen[T any](fields []string, allowZero bool) KeyGen[T] {
var kgen KeyGen[T]
// Preallocate expected struct field slice.
kgen.fields = make([]structfield, len(fields))
// Get the reflected struct ptr type.
t := reflect.TypeOf((*T)(nil)).Elem()
for i, fieldName := range fields {
// Split name to account for nesting.
names := strings.Split(fieldName, ".")
// Look for a usable struct field from type.
sfield, ok := findField(t, names, allowZero)
if !ok {
panicf("failed finding field: %s", fieldName)
}
// Set parsed struct field.
kgen.fields[i] = sfield
}
// Set config flags.
kgen.zero = allowZero
return kgen
}
// FromParts generates key string from individual key parts.
func (kgen *KeyGen[T]) FromParts(parts ...any) (key string, ok bool) {
buf := getBuf()
if ok = kgen.AppendFromParts(buf, parts...); ok {
key = string(buf.B)
}
putBuf(buf)
return
}
// FromValue generates key string from a value, via reflection.
func (kgen *KeyGen[T]) FromValue(value T) (key string, ok bool) {
buf := getBuf()
rvalue := reflect.ValueOf(value)
if ok = kgen.appendFromRValue(buf, rvalue); ok {
key = string(buf.B)
}
putBuf(buf)
return
}
// AppendFromParts generates key string into provided buffer, from individual key parts.
func (kgen *KeyGen[T]) AppendFromParts(buf *byteutil.Buffer, parts ...any) bool {
if len(parts) != len(kgen.fields) {
// User must provide correct number of parts for key.
panicf("incorrect number key parts: want=%d received=%d",
len(parts),
len(kgen.fields),
)
}
if kgen.zero {
// Zero values are permitted,
// mangle all values and ignore
// zero value return booleans.
for i, part := range parts {
// Mangle this value into buffer.
_ = kgen.fields[i].Mangle(buf, part)
// Append part separator.
buf.B = append(buf.B, '.')
}
} else {
// Zero values are NOT permitted.
for i, part := range parts {
// Mangle this value into buffer.
z := kgen.fields[i].Mangle(buf, part)
if z {
// The value was zero for
// this type, return early.
return false
}
// Append part separator.
buf.B = append(buf.B, '.')
}
}
// Drop the last separator.
buf.B = buf.B[:len(buf.B)-1]
return true
}
// AppendFromValue generates key string into provided buffer, from a value via reflection.
func (kgen *KeyGen[T]) AppendFromValue(buf *byteutil.Buffer, value T) bool {
return kgen.appendFromRValue(buf, reflect.ValueOf(value))
}
// appendFromRValue is the underlying generator function for the exported ___FromValue() functions,
// accepting a reflected input. We do not expose this as the reflected value is EXPECTED to be right.
func (kgen *KeyGen[T]) appendFromRValue(buf *byteutil.Buffer, rvalue reflect.Value) bool {
// Follow any ptrs leading to value.
for rvalue.Kind() == reflect.Pointer {
rvalue = rvalue.Elem()
}
if kgen.zero {
// Zero values are permitted,
// mangle all values and ignore
// zero value return booleans.
for i := range kgen.fields {
// Get the reflect value's field at idx.
fv := rvalue.FieldByIndex(kgen.fields[i].index)
fi := fv.Interface()
// Mangle this value into buffer.
_ = kgen.fields[i].Mangle(buf, fi)
// Append part separator.
buf.B = append(buf.B, '.')
}
} else {
// Zero values are NOT permitted.
for i := range kgen.fields {
// Get the reflect value's field at idx.
fv := rvalue.FieldByIndex(kgen.fields[i].index)
fi := fv.Interface()
// Mangle this value into buffer.
z := kgen.fields[i].Mangle(buf, fi)
if z {
// The value was zero for
// this type, return early.
return false
}
// Append part separator.
buf.B = append(buf.B, '.')
}
}
// Drop the last separator.
buf.B = buf.B[:len(buf.B)-1]
return true
}
type structfield struct {
// index is the reflected index
// of this field (this takes into
// account struct nesting).
index []int
// zero is the possible mangled
// zero value for this field.
zero string
// mangler is the mangler function for
// serializing values of this field.
mangler mangler.Mangler
}
// Mangle mangles the given value, using the determined type-appropriate
// field's type. The returned boolean indicates whether this is a zero value.
func (f *structfield) Mangle(buf *byteutil.Buffer, value any) (isZero bool) {
s := len(buf.B) // start pos.
buf.B = f.mangler(buf.B, value)
e := len(buf.B) // end pos.
isZero = (f.zero == string(buf.B[s:e]))
return
}

130
vendor/codeberg.org/gruf/go-structr/list.go generated vendored Normal file
View File

@ -0,0 +1,130 @@
package structr
// elem represents an element
// in a doubly-linked list.
type elem[T any] struct {
next *elem[T]
prev *elem[T]
Value T
}
// list implements a doubly-linked list, where:
// - head = index 0 (i.e. the front)
// - tail = index n-1 (i.e. the back)
type list[T any] struct {
head *elem[T]
tail *elem[T]
len int
}
func list_acquire[T any](c *Cache[T]) *list[*result[T]] {
var l *list[*result[T]]
if len(c.llsPool) == 0 {
// Allocate new list.
l = new(list[*result[T]])
} else {
// Pop list from pool slice.
l = c.llsPool[len(c.llsPool)-1]
c.llsPool = c.llsPool[:len(c.llsPool)-1]
}
return l
}
func list_release[T any](c *Cache[T], l *list[*result[T]]) {
// Reset list.
l.head = nil
l.tail = nil
l.len = 0
// Release list to memory pool.
c.llsPool = append(c.llsPool, l)
}
// pushFront pushes new 'elem' to front of list.
func (l *list[T]) pushFront(elem *elem[T]) {
if l.len == 0 {
// Set new tail + head
l.head = elem
l.tail = elem
// Link elem to itself
elem.next = elem
elem.prev = elem
} else {
oldHead := l.head
// Link to old head
elem.next = oldHead
oldHead.prev = elem
// Link up to tail
elem.prev = l.tail
l.tail.next = elem
// Set new head
l.head = elem
}
// Incr count
l.len++
}
// moveFront calls remove() on elem, followed by pushFront().
func (l *list[T]) moveFront(elem *elem[T]) {
l.remove(elem)
l.pushFront(elem)
}
// remove removes the 'elem' from the list.
func (l *list[T]) remove(elem *elem[T]) {
if l.len <= 1 {
// Drop elem's links
elem.next = nil
elem.prev = nil
// Only elem in list
l.head = nil
l.tail = nil
l.len = 0
return
}
// Get surrounding elems
next := elem.next
prev := elem.prev
// Relink chain
next.prev = prev
prev.next = next
switch elem {
// Set new head
case l.head:
l.head = next
// Set new tail
case l.tail:
l.tail = prev
}
// Drop elem's links
elem.next = nil
elem.prev = nil
// Decr count
l.len--
}
// rangefn ranges all the elements in the list, passing each to fn.
func (l *list[T]) rangefn(fn func(*elem[T])) {
if fn == nil {
panic("nil fn")
}
elem := l.head
for i := 0; i < l.len; i++ {
fn(elem)
elem = elem.next
}
}

76
vendor/codeberg.org/gruf/go-structr/result.go generated vendored Normal file
View File

@ -0,0 +1,76 @@
package structr
type result[T any] struct {
// linked list entry this result is
// stored under in Cache.lruList.
entry elem[*result[T]]
// keys tracks the indices
// result is stored under.
keys []*indexkey[T]
// cached value.
value T
// cached error.
err error
}
func result_acquire[T any](c *Cache[T]) *result[T] {
var res *result[T]
if len(c.resPool) == 0 {
// Allocate new result.
res = new(result[T])
} else {
// Pop result from pool slice.
res = c.resPool[len(c.resPool)-1]
c.resPool = c.resPool[:len(c.resPool)-1]
}
// Push to front of LRU list.
c.lruList.pushFront(&res.entry)
res.entry.Value = res
return res
}
func result_release[T any](c *Cache[T], res *result[T]) {
// Remove from the LRU list.
c.lruList.remove(&res.entry)
res.entry.Value = nil
var zero T
// Reset all result fields.
res.keys = res.keys[:0]
res.value = zero
res.err = nil
// Release result to memory pool.
c.resPool = append(c.resPool, res)
}
func result_dropIndex[T any](c *Cache[T], res *result[T], index *Index[T]) {
for i := 0; i < len(res.keys); i++ {
if res.keys[i].index != index {
// Prof. Obiwan:
// this is not the index
// we are looking for.
continue
}
// Get index key ptr.
ikey := res.keys[i]
// Move all index keys down + reslice.
copy(res.keys[i:], res.keys[i+1:])
res.keys = res.keys[:len(res.keys)-1]
// Release ikey to memory pool.
indexkey_release(c, ikey)
return
}
}

118
vendor/codeberg.org/gruf/go-structr/util.go generated vendored Normal file
View File

@ -0,0 +1,118 @@
package structr
import (
"fmt"
"reflect"
"sync"
"unicode"
"unicode/utf8"
"codeberg.org/gruf/go-byteutil"
"codeberg.org/gruf/go-mangler"
)
// findField will search for a struct field with given set of names, where names is a len > 0 slice of names account for nesting.
func findField(t reflect.Type, names []string, allowZero bool) (sfield structfield, ok bool) {
var (
// isExported returns whether name is exported
// from a package; can be func or struct field.
isExported = func(name string) bool {
r, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(r)
}
// popName pops the next name from
// the provided slice of field names.
popName = func() string {
// Pop next name.
name := names[0]
names = names[1:]
// Ensure valid name.
if !isExported(name) {
panicf("field is not exported: %s", name)
}
return name
}
// field is the iteratively searched-for
// struct field value in below loop.
field reflect.StructField
)
for len(names) > 0 {
// Pop next name.
name := popName()
// Follow any ptrs leading to field.
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
// The end type after following ptrs must be struct.
panicf("field %s is not struct (ptr): %s", t, name)
}
// Look for next field by name.
field, ok = t.FieldByName(name)
if !ok {
return
}
// Append next set of indices required to reach field.
sfield.index = append(sfield.index, field.Index...)
// Set the next type.
t = field.Type
}
// Get final type mangler func.
sfield.mangler = mangler.Get(t)
if allowZero {
var buf []byte
// Allocate field instance.
v := reflect.New(field.Type)
v = v.Elem()
// Serialize this zero value into buf.
buf = sfield.mangler(buf, v.Interface())
// Set zero value str.
sfield.zero = string(buf)
}
return
}
// panicf provides a panic with string formatting.
func panicf(format string, args ...any) {
panic(fmt.Sprintf(format, args...))
}
// bufpool provides a memory pool of byte
// buffers used when encoding key types.
var bufPool sync.Pool
// getBuf fetches buffer from memory pool.
func getBuf() *byteutil.Buffer {
v := bufPool.Get()
if v == nil {
buf := new(byteutil.Buffer)
buf.B = make([]byte, 0, 512)
v = buf
}
return v.(*byteutil.Buffer)
}
// putBuf replaces buffer in memory pool.
func putBuf(buf *byteutil.Buffer) {
if buf.Cap() > int(^uint16(0)) {
return // drop large bufs
}
buf.Reset()
bufPool.Put(buf)
}

4
vendor/modules.txt vendored
View File

@ -16,7 +16,6 @@ codeberg.org/gruf/go-byteutil
# codeberg.org/gruf/go-cache/v3 v3.5.7 # codeberg.org/gruf/go-cache/v3 v3.5.7
## explicit; go 1.19 ## explicit; go 1.19
codeberg.org/gruf/go-cache/v3 codeberg.org/gruf/go-cache/v3
codeberg.org/gruf/go-cache/v3/result
codeberg.org/gruf/go-cache/v3/simple codeberg.org/gruf/go-cache/v3/simple
codeberg.org/gruf/go-cache/v3/ttl codeberg.org/gruf/go-cache/v3/ttl
# codeberg.org/gruf/go-debug v1.3.0 # codeberg.org/gruf/go-debug v1.3.0
@ -60,6 +59,9 @@ codeberg.org/gruf/go-sched
## explicit; go 1.19 ## explicit; go 1.19
codeberg.org/gruf/go-store/v2/storage codeberg.org/gruf/go-store/v2/storage
codeberg.org/gruf/go-store/v2/util codeberg.org/gruf/go-store/v2/util
# codeberg.org/gruf/go-structr v0.1.1
## explicit; go 1.21
codeberg.org/gruf/go-structr
# codeberg.org/superseriousbusiness/exif-terminator v0.7.0 # codeberg.org/superseriousbusiness/exif-terminator v0.7.0
## explicit; go 1.21 ## explicit; go 1.21
codeberg.org/superseriousbusiness/exif-terminator codeberg.org/superseriousbusiness/exif-terminator