diff --git a/internal/cache/db.go b/internal/cache/db.go index fe9085613..dd4e8b212 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -18,6 +18,8 @@ package cache import ( + "sync/atomic" + "codeberg.org/gruf/go-structr" "github.com/superseriousbusiness/gotosocial/internal/cache/domain" "github.com/superseriousbusiness/gotosocial/internal/config" @@ -136,6 +138,14 @@ type DBCaches struct { // Instance provides access to the gtsmodel Instance database cache. Instance StructCache[*gtsmodel.Instance] + // LocalInstance provides caching for + // simple + common local instance queries. + LocalInstance struct { + Domains atomic.Pointer[int] + Statuses atomic.Pointer[int] + Users atomic.Pointer[int] + } + // InteractionRequest provides access to the gtsmodel InteractionRequest database cache. InteractionRequest StructCache[*gtsmodel.InteractionRequest] @@ -849,9 +859,10 @@ func (c *Caches) initInstance() { {Fields: "ID"}, {Fields: "Domain"}, }, - MaxSize: cap, - IgnoreErr: ignoreErrors, - Copy: copyF, + MaxSize: cap, + IgnoreErr: ignoreErrors, + Copy: copyF, + Invalidate: c.OnInvalidateInstance, }) } diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go index ca12e412c..9b42e88f6 100644 --- a/internal/cache/invalidate.go +++ b/internal/cache/invalidate.go @@ -19,6 +19,7 @@ package cache import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" ) // Below are cache invalidation hooks between other caches, @@ -178,6 +179,11 @@ func (c *Caches) OnInvalidateFollowRequest(followReq *gtsmodel.FollowRequest) { ) } +func (c *Caches) OnInvalidateInstance(instance *gtsmodel.Instance) { + // Invalidate the local domains count. + c.DB.LocalInstance.Domains.Store(nil) +} + func (c *Caches) OnInvalidateList(list *gtsmodel.List) { // Invalidate list IDs cache. c.DB.ListIDs.Invalidate( @@ -255,6 +261,11 @@ func (c *Caches) OnInvalidateStatus(status *gtsmodel.Status) { // Invalidate cache of attached poll ID. c.DB.Poll.Invalidate("ID", status.PollID) } + + if util.PtrOrZero(status.Local) { + // Invalidate the local statuses count. + c.DB.LocalInstance.Statuses.Store(nil) + } } func (c *Caches) OnInvalidateStatusBookmark(bookmark *gtsmodel.StatusBookmark) { @@ -271,6 +282,9 @@ func (c *Caches) OnInvalidateUser(user *gtsmodel.User) { // Invalidate local account ID cached visibility. c.Visibility.Invalidate("ItemID", user.AccountID) c.Visibility.Invalidate("RequesterID", user.AccountID) + + // Invalidate the local users count. + c.DB.LocalInstance.Users.Store(nil) } func (c *Caches) OnInvalidateUserMute(mute *gtsmodel.UserMute) { diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 008b6c8f3..419951253 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -39,6 +39,15 @@ type instanceDB struct { } func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, error) { + localhost := (domain == config.GetHost() || domain == config.GetAccountDomain()) + + if localhost { + // Check for a cached instance user count, if so return this. + if n := i.state.Caches.DB.LocalInstance.Users.Load(); n != nil { + return *n, nil + } + } + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). @@ -46,7 +55,7 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int Where("? != ?", bun.Ident("account.username"), domain). Where("? IS NULL", bun.Ident("account.suspended_at")) - if domain == config.GetHost() || domain == config.GetAccountDomain() { + if localhost { // If the domain is *this* domain, just // count where the domain field is null. q = q.Where("? IS NULL", bun.Ident("account.domain")) @@ -58,15 +67,30 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int if err != nil { return 0, err } + + if localhost { + // Update cached instance users account value. + i.state.Caches.DB.LocalInstance.Users.Store(&count) + } + return count, nil } func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, error) { + localhost := (domain == config.GetHost() || domain == config.GetAccountDomain()) + + if localhost { + // Check for a cached instance statuses count, if so return this. + if n := i.state.Caches.DB.LocalInstance.Statuses.Load(); n != nil { + return *n, nil + } + } + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")) - if domain == config.GetHost() || domain == config.GetAccountDomain() { + if localhost { // if the domain is *this* domain, just count where local is true q = q.Where("? = ?", bun.Ident("status.local"), true) } else { @@ -83,15 +107,30 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) ( if err != nil { return 0, err } + + if localhost { + // Update cached instance statuses account value. + i.state.Caches.DB.LocalInstance.Statuses.Store(&count) + } + return count, nil } func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, error) { + localhost := (domain == config.GetHost() || domain == config.GetAccountDomain()) + + if localhost { + // Check for a cached instance domains count, if so return this. + if n := i.state.Caches.DB.LocalInstance.Domains.Load(); n != nil { + return *n, nil + } + } + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")) - if domain == config.GetHost() { + if localhost { // if the domain is *this* domain, just count other instances it knows about // exclude domains that are blocked q = q. @@ -106,6 +145,12 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i if err != nil { return 0, err } + + if localhost { + // Update cached instance domains account value. + i.state.Caches.DB.LocalInstance.Domains.Store(&count) + } + return count, nil } @@ -215,13 +260,15 @@ func (i *instanceDB) PopulateInstance(ctx context.Context, instance *gtsmodel.In } func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error { - // Normalize the domain as punycode var err error + + // Normalize the domain as punycode instance.Domain, err = util.Punify(instance.Domain) if err != nil { return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) } + // Store the new instance model in database, invalidating cache. return i.state.Caches.DB.Instance.Store(instance, func() error { _, err := i.db.NewInsert().Model(instance).Exec(ctx) return err diff --git a/internal/filter/visibility/status.go b/internal/filter/visibility/status.go index be59e800e..a0f971464 100644 --- a/internal/filter/visibility/status.go +++ b/internal/filter/visibility/status.go @@ -104,18 +104,20 @@ func (f *Filter) isStatusVisible( return false, nil } - if util.PtrOrValue(status.PendingApproval, false) { + if util.PtrOrZero(status.PendingApproval) { // Use a different visibility heuristic // for pending approval statuses. - return f.isPendingStatusVisible(ctx, + return isPendingStatusVisible( requester, status, - ) + ), nil } if requester == nil { // Use a different visibility // heuristic for unauthed requests. - return f.isStatusVisibleUnauthed(ctx, status) + return f.isStatusVisibleUnauthed( + ctx, status, + ) } /* @@ -210,45 +212,42 @@ func (f *Filter) isStatusVisible( } } -func (f *Filter) isPendingStatusVisible( - _ context.Context, - requester *gtsmodel.Account, - status *gtsmodel.Status, -) (bool, error) { +// isPendingStatusVisible returns whether a status pending approval is visible to requester. +func isPendingStatusVisible(requester *gtsmodel.Account, status *gtsmodel.Status) bool { if requester == nil { // Any old tom, dick, and harry can't // see pending-approval statuses, // no matter what their visibility. - return false, nil + return false } if status.AccountID == requester.ID { // This is requester's status, // so they can always see it. - return true, nil + return true } if status.InReplyToAccountID == requester.ID { // This status replies to requester, // so they can always see it (else // they can't approve it). - return true, nil + return true } if status.BoostOfAccountID == requester.ID { // This status boosts requester, // so they can always see it. - return true, nil + return true } - // Nobody else can see this. - return false, nil + // Nobody else + // can see this. + return false } -func (f *Filter) isStatusVisibleUnauthed( - ctx context.Context, - status *gtsmodel.Status, -) (bool, error) { +// isStatusVisibleUnauthed returns whether status is visible without any unauthenticated account. +func (f *Filter) isStatusVisibleUnauthed(ctx context.Context, status *gtsmodel.Status) (bool, error) { + // For remote accounts, only show // Public statuses via the web. if status.Account.IsRemote() { @@ -275,8 +274,7 @@ func (f *Filter) isStatusVisibleUnauthed( } } - webVisibility := status.Account.Settings.WebVisibility - switch webVisibility { + switch webvis := status.Account.Settings.WebVisibility; webvis { // public_only: status must be Public. case gtsmodel.VisibilityPublic: @@ -296,7 +294,7 @@ func (f *Filter) isStatusVisibleUnauthed( default: return false, gtserror.Newf( "unrecognized web visibility for account %s: %s", - status.Account.ID, webVisibility, + status.Account.ID, webvis, ) } } diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index 30ef0b04d..8a5f51c21 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -48,9 +48,6 @@ var ( // ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. ErrReservedAddr = errors.New("dial within blocked / reserved IP range") - - // ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). - ErrBodyTooLarge = errors.New("body size too large") ) // Config provides configuration details for setting up a new @@ -302,7 +299,6 @@ func (c *Client) do(r *Request) (rsp *http.Response, retry bool, err error) { if errorsv2.IsV2(err, context.DeadlineExceeded, context.Canceled, - ErrBodyTooLarge, ErrReservedAddr, ) { // Non-retryable errors. diff --git a/internal/processing/common/account.go b/internal/processing/common/account.go index 05e974513..ae26e4ebd 100644 --- a/internal/processing/common/account.go +++ b/internal/processing/common/account.go @@ -42,6 +42,7 @@ func (p *Processor) GetTargetAccountBy( // Fetch the target account from db. target, err := getTargetFromDB() if err != nil && !errors.Is(err, db.ErrNoEntries) { + err := gtserror.Newf("error getting from db: %w", err) return nil, false, gtserror.NewErrorInternalError(err) } @@ -57,6 +58,7 @@ func (p *Processor) GetTargetAccountBy( // Check whether target account is visible to requesting account. visible, err = p.visFilter.AccountVisible(ctx, requester, target) if err != nil { + err := gtserror.Newf("error checking visibility: %w", err) return nil, false, gtserror.NewErrorInternalError(err) } @@ -128,7 +130,8 @@ func (p *Processor) GetVisibleTargetAccount( return target, nil } -// GetAPIAccount fetches the appropriate API account model depending on whether requester = target. +// GetAPIAccount fetches the appropriate API account +// model depending on whether requester = target. func (p *Processor) GetAPIAccount( ctx context.Context, requester *gtsmodel.Account, @@ -148,14 +151,15 @@ func (p *Processor) GetAPIAccount( } if err != nil { - err := gtserror.Newf("error converting account: %w", err) + err := gtserror.Newf("error converting: %w", err) return nil, gtserror.NewErrorInternalError(err) } return apiAcc, nil } -// GetAPIAccountBlocked fetches the limited "blocked" account model for given target. +// GetAPIAccountBlocked fetches the limited +// "blocked" account model for given target. func (p *Processor) GetAPIAccountBlocked( ctx context.Context, targetAcc *gtsmodel.Account, @@ -165,7 +169,7 @@ func (p *Processor) GetAPIAccountBlocked( ) { apiAccount, err := p.converter.AccountToAPIAccountBlocked(ctx, targetAcc) if err != nil { - err = gtserror.Newf("error converting account: %w", err) + err := gtserror.Newf("error converting: %w", err) return nil, gtserror.NewErrorInternalError(err) } return apiAccount, nil @@ -182,7 +186,7 @@ func (p *Processor) GetAPIAccountSensitive( ) { apiAccount, err := p.converter.AccountToAPIAccountSensitive(ctx, targetAcc) if err != nil { - err = gtserror.Newf("error converting account: %w", err) + err := gtserror.Newf("error converting: %w", err) return nil, gtserror.NewErrorInternalError(err) } return apiAccount, nil @@ -226,8 +230,7 @@ func (p *Processor) getVisibleAPIAccounts( ) []*apimodel.Account { // Start new log entry with // the above calling func's name. - l := log. - WithContext(ctx). + l := log.WithContext(ctx). WithField("caller", log.Caller(calldepth+1)) // Preallocate slice according to expected length. diff --git a/internal/processing/common/status.go b/internal/processing/common/status.go index a1d432eb0..dd83a2cc5 100644 --- a/internal/processing/common/status.go +++ b/internal/processing/common/status.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" statusfilter "github.com/superseriousbusiness/gotosocial/internal/filter/status" + "github.com/superseriousbusiness/gotosocial/internal/filter/usermute" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -50,6 +51,7 @@ func (p *Processor) GetTargetStatusBy( // Fetch the target status from db. target, err := getTargetFromDB() if err != nil && !errors.Is(err, db.ErrNoEntries) { + err := gtserror.Newf("error getting from db: %w", err) return nil, false, gtserror.NewErrorInternalError(err) } @@ -65,6 +67,7 @@ func (p *Processor) GetTargetStatusBy( // Check whether target status is visible to requesting account. visible, err = p.visFilter.StatusVisible(ctx, requester, target) if err != nil { + err := gtserror.Newf("error checking visibility: %w", err) return nil, false, gtserror.NewErrorInternalError(err) } @@ -174,14 +177,83 @@ func (p *Processor) GetAPIStatus( apiStatus *apimodel.Status, errWithCode gtserror.WithCode, ) { - apiStatus, err := p.converter.StatusToAPIStatus(ctx, target, requester, statusfilter.FilterContextNone, nil, nil) + apiStatus, err := p.converter.StatusToAPIStatus(ctx, + target, + requester, + statusfilter.FilterContextNone, + nil, + nil, + ) if err != nil { - err = gtserror.Newf("error converting status: %w", err) + err := gtserror.Newf("error converting: %w", err) return nil, gtserror.NewErrorInternalError(err) } return apiStatus, nil } +// GetVisibleAPIStatuses converts a slice of statuses to API +// model statuses, filtering according to visibility to requester +// along with given filter context, filters and user mutes. +// +// Please note that all errors will be logged at ERROR level, +// but will not be returned. Callers are likely to run into +// show-stopping errors in the lead-up to this function. +func (p *Processor) GetVisibleAPIStatuses( + ctx context.Context, + requester *gtsmodel.Account, + statuses []*gtsmodel.Status, + filterContext statusfilter.FilterContext, + filters []*gtsmodel.Filter, + userMutes []*gtsmodel.UserMute, +) []apimodel.Status { + + // Start new log entry with + // the calling function name + // as a field in each entry. + l := log.WithContext(ctx). + WithField("caller", log.Caller(3)) + + // Compile mutes to useable user mutes for type converter. + compUserMutes := usermute.NewCompiledUserMuteList(userMutes) + + // Iterate filtered statuses for conversion to API model. + apiStatuses := make([]apimodel.Status, 0, len(statuses)) + for _, status := range statuses { + + // Check whether status is visible to requester. + visible, err := p.visFilter.StatusVisible(ctx, + requester, + status, + ) + if err != nil { + l.Errorf("error checking visibility: %v", err) + continue + } + + if !visible { + continue + } + + // Convert to API status, taking mute / filter into account. + apiStatus, err := p.converter.StatusToAPIStatus(ctx, + status, + requester, + filterContext, + filters, + compUserMutes, + ) + if err != nil && !errors.Is(err, statusfilter.ErrHideStatus) { + l.Errorf("error converting: %v", err) + continue + } + + // Append converted status to return slice. + apiStatuses = append(apiStatuses, *apiStatus) + } + + return apiStatuses +} + // InvalidateTimelinedStatus is a shortcut function for invalidating the cached // representation one status in the home timeline and all list timelines of the // given accountID. It should only be called in cases where a status update diff --git a/internal/processing/status/context.go b/internal/processing/status/context.go index 9f3a7d089..19c6cac18 100644 --- a/internal/processing/status/context.go +++ b/internal/processing/status/context.go @@ -24,7 +24,6 @@ import ( apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" statusfilter "github.com/superseriousbusiness/gotosocial/internal/filter/status" - "github.com/superseriousbusiness/gotosocial/internal/filter/usermute" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -308,22 +307,7 @@ func (p *Processor) ContextGet( return nil, gtserror.NewErrorInternalError(err) } - convert := func( - ctx context.Context, - status *gtsmodel.Status, - requestingAccount *gtsmodel.Account, - ) (*apimodel.Status, error) { - return p.converter.StatusToAPIStatus( - ctx, - status, - requestingAccount, - statusfilter.FilterContextThread, - filters, - usermute.NewCompiledUserMuteList(mutes), - ) - } - - // Retrieve the thread context. + // Retrieve the full thread context. threadContext, errWithCode := p.contextGet( ctx, requester, @@ -333,34 +317,27 @@ func (p *Processor) ContextGet( return nil, errWithCode } - apiContext := &apimodel.ThreadContext{ - Ancestors: make([]apimodel.Status, 0, len(threadContext.ancestors)), - Descendants: make([]apimodel.Status, 0, len(threadContext.descendants)), - } + var apiContext apimodel.ThreadContext - // Convert ancestors + filter - // out ones that aren't visible. - for _, status := range threadContext.ancestors { - if v, err := p.visFilter.StatusVisible(ctx, requester, status); err == nil && v { - status, err := convert(ctx, status, requester) - if err == nil { - apiContext.Ancestors = append(apiContext.Ancestors, *status) - } - } - } + // Convert and filter the thread context ancestors. + apiContext.Ancestors = p.c.GetVisibleAPIStatuses(ctx, + requester, + threadContext.ancestors, + statusfilter.FilterContextThread, + filters, + mutes, + ) - // Convert descendants + filter - // out ones that aren't visible. - for _, status := range threadContext.descendants { - if v, err := p.visFilter.StatusVisible(ctx, requester, status); err == nil && v { - status, err := convert(ctx, status, requester) - if err == nil { - apiContext.Descendants = append(apiContext.Descendants, *status) - } - } - } + // Convert and filter the thread context descendants + apiContext.Descendants = p.c.GetVisibleAPIStatuses(ctx, + requester, + threadContext.descendants, + statusfilter.FilterContextThread, + filters, + mutes, + ) - return apiContext, nil + return &apiContext, nil } // WebContextGet is like ContextGet, but is explicitly diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go index 90cb1fed3..b071bd72e 100644 --- a/internal/processing/workers/surfacetimeline.go +++ b/internal/processing/workers/surfacetimeline.go @@ -384,8 +384,9 @@ func (s *Surface) timelineStatus( ) (bool, error) { // Ingest status into given timeline using provided function. - if inserted, err := ingest(ctx, timelineID, status); err != nil { - err = gtserror.Newf("error ingesting status %s: %w", status.ID, err) + if inserted, err := ingest(ctx, timelineID, status); err != nil && + !errors.Is(err, statusfilter.ErrHideStatus) { + err := gtserror.Newf("error ingesting status %s: %w", status.ID, err) return false, err } else if !inserted { // Nothing more to do. @@ -400,15 +401,19 @@ func (s *Surface) timelineStatus( filters, mutes, ) - if err != nil { - err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) + if err != nil && !errors.Is(err, statusfilter.ErrHideStatus) { + err := gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) return true, err } - // The status was inserted so stream it to the user. - s.Stream.Update(ctx, account, apiStatus, streamType) + if apiStatus != nil { + // The status was inserted so stream it to the user. + s.Stream.Update(ctx, account, apiStatus, streamType) + return true, nil + } - return true, nil + // Status was hidden. + return false, nil } // timelineAndNotifyStatusForTagFollowers inserts the status into the