diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index 3b21c7ebe..7ade78166 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -27,7 +27,6 @@ import ( "syscall" "time" - "codeberg.org/gruf/go-sched" "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/internal/api" @@ -123,16 +122,21 @@ var Start action.GTSAction = func(ctx context.Context) error { // Add a task to the scheduler to sweep caches. // Frequency = 1 * minute // Threshold = 80% capacity - sweep := func(time.Time) { state.Caches.Sweep(80) } - job := sched.NewJob(sweep).Every(time.Minute) - _ = state.Workers.Scheduler.Schedule(job) + _ = state.Workers.Scheduler.AddRecurring( + "@cachesweep", // id + time.Time{}, // start + time.Minute, // freq + func(context.Context, time.Time) { + state.Caches.Sweep(80) + }, + ) // Build handlers used in later initializations. mediaManager := media.NewManager(&state) oauthServer := oauth.New(ctx, dbService) typeConverter := typeutils.NewConverter(&state) filter := visibility.NewFilter(&state) - federatingDB := federatingdb.New(&state, typeConverter) + federatingDB := federatingdb.New(&state, typeConverter, filter) transportController := transport.NewController(&state, federatingDB, &federation.Clock{}, client) federator := federation.NewFederator(&state, federatingDB, transportController, typeConverter, mediaManager) diff --git a/internal/ap/extract.go b/internal/ap/extract.go index 74c4497b3..332550578 100644 --- a/internal/ap/extract.go +++ b/internal/ap/extract.go @@ -91,6 +91,105 @@ func ExtractActivityData(activity pub.Activity, rawJSON map[string]any) ([]TypeO } } +// ExtractAccountables extracts Accountable objects from a slice TypeOrIRI, returning extracted and remaining TypeOrIRIs. +func ExtractAccountables(arr []TypeOrIRI) ([]Accountable, []TypeOrIRI) { + var accounts []Accountable + + for i := 0; i < len(arr); i++ { + elem := arr[i] + + if elem.IsIRI() { + // skip IRIs + continue + } + + // Extract AS vocab type + // associated with elem. + t := elem.GetType() + + // Try cast AS type as Accountable. + account, ok := ToAccountable(t) + if !ok { + continue + } + + // Add casted accountable type. + accounts = append(accounts, account) + + // Drop elem from slice. + copy(arr[:i], arr[i+1:]) + arr = arr[:len(arr)-1] + } + + return accounts, arr +} + +// ExtractStatusables extracts Statusable objects from a slice TypeOrIRI, returning extracted and remaining TypeOrIRIs. +func ExtractStatusables(arr []TypeOrIRI) ([]Statusable, []TypeOrIRI) { + var statuses []Statusable + + for i := 0; i < len(arr); i++ { + elem := arr[i] + + if elem.IsIRI() { + // skip IRIs + continue + } + + // Extract AS vocab type + // associated with elem. + t := elem.GetType() + + // Try cast AS type as Statusable. + status, ok := ToStatusable(t) + if !ok { + continue + } + + // Add casted Statusable type. + statuses = append(statuses, status) + + // Drop elem from slice. + copy(arr[:i], arr[i+1:]) + arr = arr[:len(arr)-1] + } + + return statuses, arr +} + +// ExtractPollOptionables extracts PollOptionable objects from a slice TypeOrIRI, returning extracted and remaining TypeOrIRIs. +func ExtractPollOptionables(arr []TypeOrIRI) ([]PollOptionable, []TypeOrIRI) { + var options []PollOptionable + + for i := 0; i < len(arr); i++ { + elem := arr[i] + + if elem.IsIRI() { + // skip IRIs + continue + } + + // Extract AS vocab type + // associated with elem. + t := elem.GetType() + + // Try cast as PollOptionable. + option, ok := ToPollOptionable(t) + if !ok { + continue + } + + // Add casted PollOptionable type. + options = append(options, option) + + // Drop elem from slice. + copy(arr[:i], arr[i+1:]) + arr = arr[:len(arr)-1] + } + + return options, arr +} + // ExtractPreferredUsername returns a string representation of // an interface's preferredUsername property. Will return an // error if preferredUsername is nil, not a string, or empty. @@ -192,7 +291,7 @@ func ExtractToURIs(i WithTo) []*url.URL { // ExtractCcURIs returns a slice of URIs // that the given WithCC addresses as Cc. -func ExtractCcURIs(i WithCC) []*url.URL { +func ExtractCcURIs(i WithCc) []*url.URL { ccProp := i.GetActivityStreamsCc() if ccProp == nil { return nil diff --git a/internal/ap/interfaces.go b/internal/ap/interfaces.go index 6ba3c3735..fed69d69d 100644 --- a/internal/ap/interfaces.go +++ b/internal/ap/interfaces.go @@ -23,6 +23,21 @@ import ( "github.com/superseriousbusiness/activity/streams/vocab" ) +// IsActivityable returns whether AS vocab type name is acceptable as Activityable. +func IsActivityable(typeName string) bool { + return isActivity(typeName) || + isIntransitiveActivity(typeName) +} + +// ToActivityable safely tries to cast vocab.Type as Activityable, also checking for expected AS type names. +func ToActivityable(t vocab.Type) (Activityable, bool) { + activityable, ok := t.(Activityable) + if !ok || !IsActivityable(t.GetTypeName()) { + return nil, false + } + return activityable, true +} + // IsAccountable returns whether AS vocab type name is acceptable as Accountable. func IsAccountable(typeName string) bool { switch typeName { @@ -88,6 +103,43 @@ func ToPollable(t vocab.Type) (Pollable, bool) { return pollable, true } +// IsPollOptionable returns whether AS vocab type name is acceptable as PollOptionable. +func IsPollOptionable(typeName string) bool { + return typeName == ObjectNote +} + +// ToPollOptionable safely tries to cast vocab.Type as PollOptionable, also checking for expected AS type names. +func ToPollOptionable(t vocab.Type) (PollOptionable, bool) { + note, ok := t.(vocab.ActivityStreamsNote) + if !ok || !IsPollOptionable(t.GetTypeName()) { + return nil, false + } + if note.GetActivityStreamsContent() != nil || + note.GetActivityStreamsName() == nil { + // A PollOption is an ActivityStreamsNote + // WITHOUT a content property, instead only + // a name property. + return nil, false + } + return note, true +} + +// Activityable represents the minimum activitypub interface for representing an 'activity'. +// (see: IsActivityable() for types implementing this, though you MUST make sure to check +// the typeName as this bare interface may be implementable by non-Activityable types). +type Activityable interface { + // Activity is also a vocab.Type + vocab.Type + + WithTo + WithCc + WithBcc + WithAttributedTo + WithActor + WithObject + WithPublished +} + // Accountable represents the minimum activitypub interface for representing an 'account'. // (see: IsAccountable() for types implementing this, though you MUST make sure to check // the typeName as this bare interface may be implementable by non-Accountable types). @@ -126,7 +178,7 @@ type Statusable interface { WithURL WithAttributedTo WithTo - WithCC + WithCc WithSensitive WithConversation WithContent @@ -145,16 +197,21 @@ type Pollable interface { WithClosed WithVotersCount - // base-interface + // base-interfaces Statusable } -// PollOptionable represents the minimum activitypub interface for representing a poll 'option'. -// (see: IsPollOptionable() for types implementing this). +// PollOptionable represents the minimum activitypub interface for representing a poll 'vote'. +// (see: IsPollOptionable() for types implementing this, though you MUST make sure to check +// the typeName as this bare interface may be implementable by non-Pollable types). type PollOptionable interface { - WithTypeName + vocab.Type + WithName + WithTo + WithInReplyTo WithReplies + WithAttributedTo } // Attachmentable represents the minimum activitypub interface for representing a 'mediaAttachment'. (see: IsAttachmentable). @@ -226,13 +283,13 @@ type Announceable interface { WithObject WithPublished WithTo - WithCC + WithCc } // Addressable represents the minimum interface for an addressed activity. type Addressable interface { WithTo - WithCC + WithCc } // ReplyToable represents the minimum interface for an Activity that can be InReplyTo another activity. @@ -268,6 +325,15 @@ type TypeOrIRI interface { WithType } +// Property represents the minimum interface for an ActivityStreams property with IRIs. +type Property[T TypeOrIRI] interface { + Len() int + At(int) T + + AppendIRI(*url.URL) + SetIRI(int, *url.URL) +} + // WithJSONLDId represents an activity with JSONLDIdProperty. type WithJSONLDId interface { GetJSONLDId() vocab.JSONLDIdProperty @@ -386,18 +452,24 @@ type WithTo interface { SetActivityStreamsTo(vocab.ActivityStreamsToProperty) } +// WithCC represents an activity with ActivityStreamsCcProperty +type WithCc interface { + GetActivityStreamsCc() vocab.ActivityStreamsCcProperty + SetActivityStreamsCc(vocab.ActivityStreamsCcProperty) +} + +// WithCC represents an activity with ActivityStreamsBccProperty +type WithBcc interface { + GetActivityStreamsBcc() vocab.ActivityStreamsBccProperty + SetActivityStreamsBcc(vocab.ActivityStreamsBccProperty) +} + // WithInReplyTo represents an activity with ActivityStreamsInReplyToProperty type WithInReplyTo interface { GetActivityStreamsInReplyTo() vocab.ActivityStreamsInReplyToProperty SetActivityStreamsInReplyTo(vocab.ActivityStreamsInReplyToProperty) } -// WithCC represents an activity with ActivityStreamsCcProperty -type WithCC interface { - GetActivityStreamsCc() vocab.ActivityStreamsCcProperty - SetActivityStreamsCc(vocab.ActivityStreamsCcProperty) -} - // WithSensitive represents an activity with ActivityStreamsSensitiveProperty type WithSensitive interface { GetActivityStreamsSensitive() vocab.ActivityStreamsSensitiveProperty diff --git a/internal/ap/properties.go b/internal/ap/properties.go new file mode 100644 index 000000000..d8441003f --- /dev/null +++ b/internal/ap/properties.go @@ -0,0 +1,325 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package ap + +import ( + "fmt" + "net/url" + "time" + + "github.com/superseriousbusiness/activity/streams" + "github.com/superseriousbusiness/activity/streams/vocab" +) + +// MustGet performs the given 'Get$Property(with) (T, error)' signature function, panicking on error. +// func MustGet[W, T any](fn func(W) (T, error), with W) T { +// t, err := fn(with) +// if err != nil { +// panicfAt(3, "error getting property on %T: %w", with, err) +// } +// return t +// } + +// MustSet performs the given 'Set$Property(with, T) error' signature function, panicking on error. +// func MustSet[W, T any](fn func(W, T) error, with W, value T) { +// err := fn(with, value) +// if err != nil { +// panicfAt(3, "error setting property on %T: %w", with, err) +// } +// } + +// AppendSet performs the given 'Append$Property(with, ...T) error' signature function, panicking on error. +// func MustAppend[W, T any](fn func(W, ...T) error, with W, values ...T) { +// err := fn(with, values...) +// if err != nil { +// panicfAt(3, "error appending properties on %T: %w", with, err) +// } +// } + +// GetJSONLDId returns the ID of 'with', or nil. +func GetJSONLDId(with WithJSONLDId) *url.URL { + idProp := with.GetJSONLDId() + if idProp == nil { + return nil + } + id := idProp.Get() + if id == nil { + return nil + } + return id +} + +// SetJSONLDId sets the given URL to the JSONLD ID of 'with'. +func SetJSONLDId(with WithJSONLDId, id *url.URL) { + idProp := with.GetJSONLDId() + if idProp == nil { + idProp = streams.NewJSONLDIdProperty() + } + idProp.SetIRI(id) + with.SetJSONLDId(idProp) +} + +// SetJSONLDIdStr sets the given string to the JSONLDID of 'with'. Returns error +func SetJSONLDIdStr(with WithJSONLDId, id string) error { + u, err := url.Parse(id) + if err != nil { + return fmt.Errorf("error parsing id url: %w", err) + } + SetJSONLDId(with, u) + return nil +} + +// GetTo returns the IRIs contained in the To property of 'with'. Panics on entries with missing ID. +func GetTo(with WithTo) []*url.URL { + toProp := with.GetActivityStreamsTo() + return getIRIs[vocab.ActivityStreamsToPropertyIterator](toProp) +} + +// AppendTo appends the given IRIs to the To property of 'with'. +func AppendTo(with WithTo, to ...*url.URL) { + appendIRIs(func() Property[vocab.ActivityStreamsToPropertyIterator] { + toProp := with.GetActivityStreamsTo() + if toProp == nil { + toProp = streams.NewActivityStreamsToProperty() + with.SetActivityStreamsTo(toProp) + } + return toProp + }, to...) +} + +// GetCc returns the IRIs contained in the Cc property of 'with'. Panics on entries with missing ID. +func GetCc(with WithCc) []*url.URL { + ccProp := with.GetActivityStreamsCc() + return getIRIs[vocab.ActivityStreamsCcPropertyIterator](ccProp) +} + +// AppendCc appends the given IRIs to the Cc property of 'with'. +func AppendCc(with WithCc, cc ...*url.URL) { + appendIRIs(func() Property[vocab.ActivityStreamsCcPropertyIterator] { + ccProp := with.GetActivityStreamsCc() + if ccProp == nil { + ccProp = streams.NewActivityStreamsCcProperty() + with.SetActivityStreamsCc(ccProp) + } + return ccProp + }, cc...) +} + +// GetBcc returns the IRIs contained in the Bcc property of 'with'. Panics on entries with missing ID. +func GetBcc(with WithBcc) []*url.URL { + bccProp := with.GetActivityStreamsBcc() + return getIRIs[vocab.ActivityStreamsBccPropertyIterator](bccProp) +} + +// AppendBcc appends the given IRIs to the Bcc property of 'with'. +func AppendBcc(with WithBcc, bcc ...*url.URL) { + appendIRIs(func() Property[vocab.ActivityStreamsBccPropertyIterator] { + bccProp := with.GetActivityStreamsBcc() + if bccProp == nil { + bccProp = streams.NewActivityStreamsBccProperty() + with.SetActivityStreamsBcc(bccProp) + } + return bccProp + }, bcc...) +} + +// GetActor returns the IRIs contained in the Actor property of 'with'. Panics on entries with missing ID. +func GetActor(with WithActor) []*url.URL { + actorProp := with.GetActivityStreamsActor() + return getIRIs[vocab.ActivityStreamsActorPropertyIterator](actorProp) +} + +// AppendActor appends the given IRIs to the Actor property of 'with'. +func AppendActor(with WithActor, actor ...*url.URL) { + appendIRIs(func() Property[vocab.ActivityStreamsActorPropertyIterator] { + actorProp := with.GetActivityStreamsActor() + if actorProp == nil { + actorProp = streams.NewActivityStreamsActorProperty() + with.SetActivityStreamsActor(actorProp) + } + return actorProp + }, actor...) +} + +// GetAttributedTo returns the IRIs contained in the AttributedTo property of 'with'. Panics on entries with missing ID. +func GetAttributedTo(with WithAttributedTo) []*url.URL { + attribProp := with.GetActivityStreamsAttributedTo() + return getIRIs[vocab.ActivityStreamsAttributedToPropertyIterator](attribProp) +} + +// AppendAttributedTo appends the given IRIs to the AttributedTo property of 'with'. +func AppendAttributedTo(with WithAttributedTo, attribTo ...*url.URL) { + appendIRIs(func() Property[vocab.ActivityStreamsAttributedToPropertyIterator] { + attribProp := with.GetActivityStreamsAttributedTo() + if attribProp == nil { + attribProp = streams.NewActivityStreamsAttributedToProperty() + with.SetActivityStreamsAttributedTo(attribProp) + } + return attribProp + }, attribTo...) +} + +// GetInReplyTo returns the IRIs contained in the InReplyTo property of 'with'. Panics on entries with missing ID. +func GetInReplyTo(with WithInReplyTo) []*url.URL { + replyProp := with.GetActivityStreamsInReplyTo() + return getIRIs[vocab.ActivityStreamsInReplyToPropertyIterator](replyProp) +} + +// AppendInReplyTo appends the given IRIs to the InReplyTo property of 'with'. +func AppendInReplyTo(with WithInReplyTo, replyTo ...*url.URL) { + appendIRIs(func() Property[vocab.ActivityStreamsInReplyToPropertyIterator] { + replyProp := with.GetActivityStreamsInReplyTo() + if replyProp == nil { + replyProp = streams.NewActivityStreamsInReplyToProperty() + with.SetActivityStreamsInReplyTo(replyProp) + } + return replyProp + }, replyTo...) +} + +// GetPublished returns the time contained in the Published property of 'with'. +func GetPublished(with WithPublished) time.Time { + publishProp := with.GetActivityStreamsPublished() + if publishProp == nil { + return time.Time{} + } + return publishProp.Get() +} + +// SetPublished sets the given time on the Published property of 'with'. +func SetPublished(with WithPublished, published time.Time) { + publishProp := with.GetActivityStreamsPublished() + if publishProp == nil { + publishProp = streams.NewActivityStreamsPublishedProperty() + with.SetActivityStreamsPublished(publishProp) + } + publishProp.Set(published) +} + +// GetEndTime returns the time contained in the EndTime property of 'with'. +func GetEndTime(with WithEndTime) time.Time { + endTimeProp := with.GetActivityStreamsEndTime() + if endTimeProp == nil { + return time.Time{} + } + return endTimeProp.Get() +} + +// SetEndTime sets the given time on the EndTime property of 'with'. +func SetEndTime(with WithEndTime, end time.Time) { + endTimeProp := with.GetActivityStreamsEndTime() + if endTimeProp == nil { + endTimeProp = streams.NewActivityStreamsEndTimeProperty() + with.SetActivityStreamsEndTime(endTimeProp) + } + endTimeProp.Set(end) +} + +// GetEndTime returns the times contained in the Closed property of 'with'. +func GetClosed(with WithClosed) []time.Time { + closedProp := with.GetActivityStreamsClosed() + if closedProp == nil || closedProp.Len() == 0 { + return nil + } + closed := make([]time.Time, 0, closedProp.Len()) + for i := 0; i < closedProp.Len(); i++ { + at := closedProp.At(i) + if t := at.GetXMLSchemaDateTime(); !t.IsZero() { + closed = append(closed, t) + } + } + return closed +} + +// AppendClosed appends the given times to the Closed property of 'with'. +func AppendClosed(with WithClosed, closed ...time.Time) { + if len(closed) == 0 { + return + } + closedProp := with.GetActivityStreamsClosed() + if closedProp == nil { + closedProp = streams.NewActivityStreamsClosedProperty() + with.SetActivityStreamsClosed(closedProp) + } + for _, closed := range closed { + closedProp.AppendXMLSchemaDateTime(closed) + } +} + +// GetVotersCount returns the integer contained in the VotersCount property of 'with', if found. +func GetVotersCount(with WithVotersCount) int { + votersProp := with.GetTootVotersCount() + if votersProp == nil { + return 0 + } + return votersProp.Get() +} + +// SetVotersCount sets the given count on the VotersCount property of 'with'. +func SetVotersCount(with WithVotersCount, count int) { + votersProp := with.GetTootVotersCount() + if votersProp == nil { + votersProp = streams.NewTootVotersCountProperty() + with.SetTootVotersCount(votersProp) + } + votersProp.Set(count) +} + +func getIRIs[T TypeOrIRI](prop Property[T]) []*url.URL { + if prop == nil || prop.Len() == 0 { + return nil + } + ids := make([]*url.URL, 0, prop.Len()) + for i := 0; i < prop.Len(); i++ { + at := prop.At(i) + if t := at.GetType(); t != nil { + id := GetJSONLDId(t) + if id != nil { + ids = append(ids, id) + continue + } + } + if at.IsIRI() { + id := at.GetIRI() + if id != nil { + ids = append(ids, id) + continue + } + } + } + return ids +} + +func appendIRIs[T TypeOrIRI](getProp func() Property[T], iri ...*url.URL) { + if len(iri) == 0 { + return + } + prop := getProp() + if prop == nil { + // check outside loop. + panic("prop not set") + } + for _, iri := range iri { + prop.AppendIRI(iri) + } +} + +// panicfAt panics with a call to gtserror.NewfAt() with given args (+1 to calldepth). +// func panicfAt(calldepth int, msg string, args ...any) { +// panic(gtserror.NewfAt(calldepth+1, msg, args...)) +// } diff --git a/internal/cleaner/cleaner.go b/internal/cleaner/cleaner.go index 1139a85bb..a1209ae08 100644 --- a/internal/cleaner/cleaner.go +++ b/internal/cleaner/cleaner.go @@ -22,8 +22,6 @@ import ( "errors" "time" - "codeberg.org/gruf/go-runners" - "codeberg.org/gruf/go-sched" "codeberg.org/gruf/go-store/v2/storage" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -150,26 +148,27 @@ func (c *Cleaner) ScheduleJobs() error { firstCleanupAt = firstCleanupAt.Add(cleanupEvery) } - // Get ctx associated with scheduler run state. - done := c.state.Workers.Scheduler.Done() - doneCtx := runners.CancelCtx(done) - - // TODO: we'll need to do some thinking to make these - // jobs restartable if we want to implement reloads in - // the future that make call to Workers.Stop() -> Workers.Start(). + fn := func(ctx context.Context, start time.Time) { + log.Info(ctx, "starting media clean") + c.Media().All(ctx, config.GetMediaRemoteCacheDays()) + c.Emoji().All(ctx, config.GetMediaRemoteCacheDays()) + log.Infof(ctx, "finished media clean after %s", time.Since(start)) + } log.Infof(nil, "scheduling media clean to run every %s, starting from %s; next clean will run at %s", cleanupEvery, cleanupFromStr, firstCleanupAt, ) - // Schedule the cleaning tasks to execute according to given schedule. - c.state.Workers.Scheduler.Schedule(sched.NewJob(func(start time.Time) { - log.Info(nil, "starting media clean") - c.Media().All(doneCtx, config.GetMediaRemoteCacheDays()) - c.Emoji().All(doneCtx, config.GetMediaRemoteCacheDays()) - log.Infof(nil, "finished media clean after %s", time.Since(start)) - }).EveryAt(firstCleanupAt, cleanupEvery)) + // Schedule the cleaning to execute according to schedule. + if !c.state.Workers.Scheduler.AddRecurring( + "@mediacleanup", + firstCleanupAt, + cleanupEvery, + fn, + ) { + panic("failed to schedule @mediacleanup") + } return nil } diff --git a/internal/cleaner/cleaner_test.go b/internal/cleaner/cleaner_test.go index d23dac504..4524e9609 100644 --- a/internal/cleaner/cleaner_test.go +++ b/internal/cleaner/cleaner_test.go @@ -48,7 +48,7 @@ func (suite *CleanerTestSuite) SetupTest() { suite.state.Caches.Init() // Ensure scheduler started (even if unused). - suite.state.Workers.Scheduler.Start(nil) + suite.state.Workers.Scheduler.Start() // Initialize test database. _ = testrig.NewTestDB(&suite.state) @@ -58,6 +58,7 @@ func (suite *CleanerTestSuite) SetupTest() { suite.state.Storage = testrig.NewInMemoryStorage() // Initialize test cleaner instance. + testrig.StartWorkers(&suite.state) suite.cleaner = cleaner.New(&suite.state) // Allocate new test model emojis. @@ -66,6 +67,7 @@ func (suite *CleanerTestSuite) SetupTest() { func (suite *CleanerTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.state.DB) + testrig.StopWorkers(&suite.state) } // mapvals extracts a slice of values from the values contained within the map. diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index 58f07f9cd..a4e74de3c 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -945,7 +945,7 @@ func (d *Dereferencer) dereferenceAccountFeatured(ctx context.Context, requestUs // we still know it was *meant* to be pinned. statusURIs = append(statusURIs, statusURI) - status, _, err := d.getStatusByURI(ctx, requestUser, statusURI) + status, _, _, err := d.getStatusByURI(ctx, requestUser, statusURI) if err != nil { // We couldn't get the status, bummer. Just log + move on, we can try later. log.Errorf(ctx, "error getting status from featured collection %s: %v", statusURI, err) diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index 712692814..4dd6d3baf 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -22,9 +22,8 @@ import ( "errors" "io" "net/url" - "time" - "slices" + "time" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/config" @@ -55,12 +54,12 @@ func statusUpToDate(status *gtsmodel.Status) bool { return false } -// GetStatusByURI will attempt to fetch a status by its URI, first checking the database. In the case of a newly-met remote model, or a remote model -// whose last_fetched date is beyond a certain interval, the status will be dereferenced. In the case of dereferencing, some low-priority status information -// may be enqueued for asynchronous fetching, e.g. dereferencing the remainder of the status thread. An ActivityPub object indicates the status was dereferenced. +// GetStatusByURI will attempt to fetch a status by its URI, first checking the database. In the case of a newly-met remote model, or a remote model whose 'last_fetched' date +// is beyond a certain interval, the status will be dereferenced. In the case of dereferencing, some low-priority status information may be enqueued for asynchronous fetching, +// e.g. dereferencing the status thread. Param 'syncParent' = true indicates to fetch status ancestors synchronously. An ActivityPub object indicates the status was dereferenced. func (d *Dereferencer) GetStatusByURI(ctx context.Context, requestUser string, uri *url.URL) (*gtsmodel.Status, ap.Statusable, error) { // Fetch and dereference status if necessary. - status, apubStatus, err := d.getStatusByURI(ctx, + status, statusable, isNew, err := d.getStatusByURI(ctx, requestUser, uri, ) @@ -68,18 +67,22 @@ func (d *Dereferencer) GetStatusByURI(ctx context.Context, requestUser string, u return nil, nil, err } - if apubStatus != nil { - // This status was updated, enqueue re-dereferencing the whole thread. - d.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) { - d.dereferenceThread(ctx, requestUser, uri, status, apubStatus) - }) + if statusable != nil { + // Deref parents + children. + d.dereferenceThread(ctx, + requestUser, + uri, + status, + statusable, + isNew, + ) } - return status, apubStatus, nil + return status, statusable, nil } // getStatusByURI is a package internal form of .GetStatusByURI() that doesn't bother dereferencing the whole thread on update. -func (d *Dereferencer) getStatusByURI(ctx context.Context, requestUser string, uri *url.URL) (*gtsmodel.Status, ap.Statusable, error) { +func (d *Dereferencer) getStatusByURI(ctx context.Context, requestUser string, uri *url.URL) (*gtsmodel.Status, ap.Statusable, bool, error) { var ( status *gtsmodel.Status uriStr = uri.String() @@ -94,7 +97,7 @@ func (d *Dereferencer) getStatusByURI(ctx context.Context, requestUser string, u uriStr, ) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, nil, gtserror.Newf("error checking database for status %s by uri: %w", uriStr, err) + return nil, nil, false, gtserror.Newf("error checking database for status %s by uri: %w", uriStr, err) } if status == nil { @@ -104,14 +107,14 @@ func (d *Dereferencer) getStatusByURI(ctx context.Context, requestUser string, u uriStr, ) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, nil, gtserror.Newf("error checking database for status %s by url: %w", uriStr, err) + return nil, nil, false, gtserror.Newf("error checking database for status %s by url: %w", uriStr, err) } } if status == nil { // Ensure that this isn't a search for a local status. if uri.Host == config.GetHost() || uri.Host == config.GetAccountDomain() { - return nil, nil, gtserror.SetUnretrievable(err) // this will be db.ErrNoEntries + return nil, nil, false, gtserror.SetUnretrievable(err) // this will be db.ErrNoEntries } // Create and pass-through a new bare-bones model for deref. @@ -127,11 +130,11 @@ func (d *Dereferencer) getStatusByURI(ctx context.Context, requestUser string, u if err := d.state.DB.PopulateStatus(ctx, status); err != nil { log.Errorf(ctx, "error populating existing status: %v", err) } - return status, nil, nil + return status, nil, false, nil } // Try to update + deref existing status model. - latest, apubStatus, err := d.enrichStatusSafely(ctx, + latest, statusable, isNew, err := d.enrichStatusSafely(ctx, requestUser, uri, status, @@ -140,17 +143,22 @@ func (d *Dereferencer) getStatusByURI(ctx context.Context, requestUser string, u if err != nil { log.Errorf(ctx, "error enriching remote status: %v", err) - // Fallback to existing. - return status, nil, nil + // Fallback to existing status. + return status, nil, false, nil } - return latest, apubStatus, nil + return latest, statusable, isNew, nil } -// RefreshStatus updates the given status if remote and last_fetched is beyond fetch interval, or if force is set. An updated status model is returned, -// but in the case of dereferencing, some low-priority status information may be enqueued for asynchronous fetching, e.g. dereferencing the remainder of the -// status thread. An ActivityPub object indicates the status was dereferenced (i.e. updated). -func (d *Dereferencer) RefreshStatus(ctx context.Context, requestUser string, status *gtsmodel.Status, apubStatus ap.Statusable, force bool) (*gtsmodel.Status, ap.Statusable, error) { +// RefreshStatus is functionally equivalent to GetStatusByURI(), except that it requires a pre +// populated status model (with AT LEAST uri set), and ALL thread dereferencing is asynchronous. +func (d *Dereferencer) RefreshStatus( + ctx context.Context, + requestUser string, + status *gtsmodel.Status, + statusable ap.Statusable, + force bool, +) (*gtsmodel.Status, ap.Statusable, error) { // Check whether needs update. if !force && statusUpToDate(status) { return status, nil, nil @@ -162,28 +170,40 @@ func (d *Dereferencer) RefreshStatus(ctx context.Context, requestUser string, st return nil, nil, gtserror.Newf("invalid status uri %q: %w", status.URI, err) } - // Try to update + deref the passed status model. - latest, apubStatus, err := d.enrichStatusSafely(ctx, + // Try to update + dereference the passed status model. + latest, statusable, isNew, err := d.enrichStatusSafely(ctx, requestUser, uri, status, - apubStatus, + statusable, ) if err != nil { return nil, nil, err } - // This status was updated, enqueue re-dereferencing the whole thread. - d.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) { - d.dereferenceThread(ctx, requestUser, uri, latest, apubStatus) - }) + if statusable != nil { + // Deref parents + children. + d.dereferenceThread(ctx, + requestUser, + uri, + status, + statusable, + isNew, + ) + } - return latest, apubStatus, nil + return latest, statusable, nil } -// RefreshStatusAsync enqueues the given status for an asychronous update fetching, if last_fetched is beyond fetch interval, or if force is set. -// This is a more optimized form of manually enqueueing .UpdateStatus() to the federation worker, since it only enqueues update if necessary. -func (d *Dereferencer) RefreshStatusAsync(ctx context.Context, requestUser string, status *gtsmodel.Status, apubStatus ap.Statusable, force bool) { +// RefreshStatusAsync is functionally equivalent to RefreshStatus(), except that ALL +// dereferencing is queued for asynchronous processing, (both thread AND status). +func (d *Dereferencer) RefreshStatusAsync( + ctx context.Context, + requestUser string, + status *gtsmodel.Status, + statusable ap.Statusable, + force bool, +) { // Check whether needs update. if !force && statusUpToDate(status) { return @@ -196,17 +216,25 @@ func (d *Dereferencer) RefreshStatusAsync(ctx context.Context, requestUser strin return } - // Enqueue a worker function to re-fetch this status async. + // Enqueue a worker function to re-fetch this status entirely async. d.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) { - latest, apubStatus, err := d.enrichStatusSafely(ctx, requestUser, uri, status, apubStatus) + latest, statusable, _, err := d.enrichStatusSafely(ctx, + requestUser, + uri, + status, + statusable, + ) if err != nil { log.Errorf(ctx, "error enriching remote status: %v", err) return } - - if apubStatus != nil { - // This status was updated, re-dereference the whole thread. - d.dereferenceThread(ctx, requestUser, uri, latest, apubStatus) + if statusable != nil { + if err := d.DereferenceStatusAncestors(ctx, requestUser, latest); err != nil { + log.Error(ctx, err) + } + if err := d.DereferenceStatusDescendants(ctx, requestUser, uri, statusable); err != nil { + log.Error(ctx, err) + } } }) } @@ -220,7 +248,7 @@ func (d *Dereferencer) enrichStatusSafely( uri *url.URL, status *gtsmodel.Status, apubStatus ap.Statusable, -) (*gtsmodel.Status, ap.Statusable, error) { +) (*gtsmodel.Status, ap.Statusable, bool, error) { uriStr := status.URI if status.ID != "" { @@ -238,6 +266,9 @@ func (d *Dereferencer) enrichStatusSafely( unlock = doOnce(unlock) defer unlock() + // This is a NEW status (to us). + isNew := (status.ID == "") + // Perform status enrichment with passed vars. latest, apubStatus, err := d.enrichStatus(ctx, requestUser, @@ -261,6 +292,7 @@ func (d *Dereferencer) enrichStatusSafely( // otherwise this indicates WE // enriched the status. apubStatus = nil + isNew = false // DATA RACE! We likely lost out to another goroutine // in a call to db.Put(Status). Look again in DB by URI. @@ -270,7 +302,7 @@ func (d *Dereferencer) enrichStatusSafely( } } - return latest, apubStatus, err + return latest, apubStatus, isNew, err } // enrichStatus will enrich the given status, whether a new @@ -343,6 +375,7 @@ func (d *Dereferencer) enrichStatus( } // Carry-over values and set fetch time. + latestStatus.UpdatedAt = status.UpdatedAt latestStatus.FetchedAt = time.Now() latestStatus.Local = status.Local diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go index 5753ce4dd..0ad8f09e4 100644 --- a/internal/federation/dereferencing/thread.go +++ b/internal/federation/dereferencing/thread.go @@ -38,15 +38,42 @@ import ( // ancesters we are willing to follow before returning error. const maxIter = 1000 -func (d *Dereferencer) dereferenceThread(ctx context.Context, username string, statusIRI *url.URL, status *gtsmodel.Status, statusable ap.Statusable) { - // Ensure that ancestors have been fully dereferenced - if err := d.DereferenceStatusAncestors(ctx, username, status); err != nil { - log.Error(ctx, err) - } +// dereferenceThread handles dereferencing status thread after +// fetch. Passing off appropriate parts to be enqueued for async +// processing, or handling some parts synchronously when required. +func (d *Dereferencer) dereferenceThread( + ctx context.Context, + requestUser string, + uri *url.URL, + status *gtsmodel.Status, + statusable ap.Statusable, + isNew bool, +) { + if isNew { + // This is a new status that we need the ancestors of in + // order to determine visibility. Perform the initial part + // of thread dereferencing, i.e. parents, synchronously. + err := d.DereferenceStatusAncestors(ctx, requestUser, status) + if err != nil { + log.Error(ctx, err) + } - // Ensure that descendants have been fully dereferenced - if err := d.DereferenceStatusDescendants(ctx, username, statusIRI, statusable); err != nil { - log.Error(ctx, err) + // Enqueue dereferencing remaining status thread, (children), asychronously . + d.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) { + if err := d.DereferenceStatusDescendants(ctx, requestUser, uri, statusable); err != nil { + log.Error(ctx, err) + } + }) + } else { + // This is an existing status, dereference the WHOLE thread asynchronously. + d.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) { + if err := d.DereferenceStatusAncestors(ctx, requestUser, status); err != nil { + log.Error(ctx, err) + } + if err := d.DereferenceStatusDescendants(ctx, requestUser, uri, statusable); err != nil { + log.Error(ctx, err) + } + }) } } @@ -157,7 +184,7 @@ func (d *Dereferencer) DereferenceStatusAncestors(ctx context.Context, username // - refetching recently fetched statuses (recursion!) // - remote domain is blocked (will return unretrievable) // - any http type error for a new status returns unretrievable - parent, _, err := d.getStatusByURI(ctx, username, inReplyToURI) + parent, _, _, err := d.getStatusByURI(ctx, username, inReplyToURI) if err == nil { // We successfully fetched the parent. // Update current status with new info. @@ -325,7 +352,7 @@ stackLoop: // - refetching recently fetched statuses (recursion!) // - remote domain is blocked (will return unretrievable) // - any http type error for a new status returns unretrievable - _, statusable, err := d.getStatusByURI(ctx, username, itemIRI) + _, statusable, _, err := d.getStatusByURI(ctx, username, itemIRI) if err != nil { if !gtserror.Unretrievable(err) { l.Errorf("error dereferencing remote status %s: %v", itemIRI, err) diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go index 14e846b15..0fb459190 100644 --- a/internal/federation/federatingdb/create.go +++ b/internal/federation/federatingdb/create.go @@ -24,7 +24,6 @@ import ( "strings" "codeberg.org/gruf/go-logger/v2/level" - "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/config" @@ -137,22 +136,37 @@ func (f *federatingDB) activityCreate( return gtserror.Newf("could not convert asType %T to ActivityStreamsCreate", asType) } - for _, object := range ap.ExtractObjects(create) { - // Try to get object as vocab.Type, - // else skip handling (likely) IRI. - objType := object.GetType() - if objType == nil { - continue - } + var errs gtserror.MultiError - if statusable, ok := ap.ToStatusable(objType); ok { - return f.createStatusable(ctx, statusable, receivingAccount, requestingAccount) - } + // Extract objects from create activity. + objects := ap.ExtractObjects(create) - // TODO: handle CREATE of other types? + // Extract Statusables from objects slice (this must be + // done AFTER extracting options due to how AS typing works). + statusables, objects := ap.ExtractStatusables(objects) + + for _, statusable := range statusables { + // Check if this is a forwarded object, i.e. did + // the account making the request also create this? + forwarded := !isSender(statusable, requestingAccount) + + // Handle create event for this statusable. + if err := f.createStatusable(ctx, + receivingAccount, + requestingAccount, + statusable, + forwarded, + ); err != nil { + errs.Appendf("error creating statusable: %w", err) + } } - return nil + if len(objects) > 0 { + // Log any unhandled objects after filtering for debug purposes. + log.Debugf(ctx, "unhandled CREATE types: %v", typeNames(objects)) + } + + return errs.Combine() } // createStatusable handles a Create activity for a Statusable. @@ -161,88 +175,36 @@ func (f *federatingDB) activityCreate( // the processor for further asynchronous processing. func (f *federatingDB) createStatusable( ctx context.Context, + receiver *gtsmodel.Account, + requester *gtsmodel.Account, statusable ap.Statusable, - receivingAccount *gtsmodel.Account, - requestingAccount *gtsmodel.Account, + forwarded bool, ) error { - // Statusable must have an attributedTo. - attrToProp := statusable.GetActivityStreamsAttributedTo() - if attrToProp == nil { - return gtserror.Newf("statusable had no attributedTo") - } - - // Statusable must have an ID. - idProp := statusable.GetJSONLDId() - if idProp == nil || !idProp.IsIRI() { - return gtserror.Newf("statusable had no id, or id was not a URI") - } - - statusableURI := idProp.GetIRI() - - // Check if we have a forward. In other words, was the - // statusable posted to our inbox by at least one actor - // who actually created it, or are they forwarding it? - forward := true - for iter := attrToProp.Begin(); iter != attrToProp.End(); iter = iter.Next() { - actorURI, err := pub.ToId(iter) - if err != nil { - return gtserror.Newf("error extracting id from attributedTo entry: %w", err) - } - - if requestingAccount.URI == actorURI.String() { - // The actor who posted this statusable to our inbox is - // (one of) its creator(s), so this is not a forward. - forward = false - break - } - } - - // Check if we already have a status entry - // for this statusable, based on the ID/URI. - statusableURIStr := statusableURI.String() - status, err := f.state.DB.GetStatusByURI(ctx, statusableURIStr) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return gtserror.Newf("db error checking existence of status %s: %w", statusableURIStr, err) - } - - if status != nil { - // We already had this status in the db, no need for further action. - log.Trace(ctx, "status already exists: %s", statusableURIStr) - return nil - } - // If we do have a forward, we should ignore the content // and instead deref based on the URI of the statusable. // // In other words, don't automatically trust whoever sent // this status to us, but fetch the authentic article from // the server it originated from. - if forward { - // Pass the statusable URI (APIri) into the processor worker - // and do the rest of the processing asynchronously. + if forwarded { + // Pass the statusable URI (APIri) into the processor + // worker and do the rest of the processing asynchronously. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, - APIri: statusableURI, + APIri: ap.GetJSONLDId(statusable), APObjectModel: nil, GTSModel: nil, - ReceivingAccount: receivingAccount, + ReceivingAccount: receiver, }) return nil } - // This is a non-forwarded status we can trust the requester on, - // convert this provided statusable data to a useable gtsmodel status. - status, err = f.converter.ASStatusToStatus(ctx, statusable) - if err != nil { - return gtserror.Newf("error converting statusable to status: %w", err) - } - // Check whether we should accept this new status. accept, err := f.shouldAcceptStatusable(ctx, - receivingAccount, - requestingAccount, - status, + receiver, + requester, + statusable, ) if err != nil { return gtserror.Newf("error checking status acceptibility: %w", err) @@ -258,65 +220,52 @@ func (f *federatingDB) createStatusable( return nil } - // ID the new status based on the time it was created. - status.ID, err = id.NewULIDFromTime(status.CreatedAt) - if err != nil { - return err - } - - // Put this newly parsed status in the database. - if err := f.state.DB.PutStatus(ctx, status); err != nil { - if errors.Is(err, db.ErrAlreadyExists) { - // The status already exists in the database, which - // means we've already processed it and some race - // condition means we didn't catch it yet. We can - // just return nil here and be done with it. - return nil - } - return gtserror.Newf("db error inserting status: %w", err) - } - // Do the rest of the processing asynchronously. The processor // will handle inserting/updating + further dereferencing the status. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, APIri: nil, + GTSModel: nil, APObjectModel: statusable, - GTSModel: status, - ReceivingAccount: receivingAccount, + ReceivingAccount: receiver, }) return nil } -func (f *federatingDB) shouldAcceptStatusable(ctx context.Context, receiver *gtsmodel.Account, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) { +func (f *federatingDB) shouldAcceptStatusable(ctx context.Context, receiver *gtsmodel.Account, requester *gtsmodel.Account, statusable ap.Statusable) (bool, error) { host := config.GetHost() accountDomain := config.GetAccountDomain() // Check whether status mentions the receiver, // this is the quickest check so perform it first. - // Prefer checking using mention Href, fall back to Name. - for _, mention := range status.Mentions { - targetURI := mention.TargetAccountURI - nameString := mention.NameString + mentions, _ := ap.ExtractMentions(statusable) + for _, mention := range mentions { - if targetURI != "" { - if targetURI == receiver.URI || targetURI == receiver.URL { - // Target URI or URL match; - // receiver is mentioned. + // Extract placeholder mention vars. + accURI := mention.TargetAccountURI + name := mention.NameString + + switch { + case accURI != "": + if accURI == receiver.URI || + accURI == receiver.URL { + // Mention target is receiver, + // they are mentioned in status. return true, nil } - } else if nameString != "" { - username, domain, err := util.ExtractNamestringParts(nameString) + + case accURI == "" && name != "": + // Only a name was provided, extract the user@domain parts. + user, domain, err := util.ExtractNamestringParts(name) if err != nil { - return false, gtserror.Newf("error checking if mentioned: %w", err) + return false, gtserror.Newf("error extracting mention name parts: %w", err) } - if (domain == host || domain == accountDomain) && - strings.EqualFold(username, receiver.Username) { - // Username + domain match; - // receiver is mentioned. + // Check if the name points to our receiving local user. + isLocal := (domain == host || domain == accountDomain) + if isLocal && strings.EqualFold(user, receiver.Username) { return true, nil } } diff --git a/internal/federation/federatingdb/create_test.go b/internal/federation/federatingdb/create_test.go index 6c18f5bd0..a1f1a7e18 100644 --- a/internal/federation/federatingdb/create_test.go +++ b/internal/federation/federatingdb/create_test.go @@ -39,6 +39,8 @@ func (suite *CreateTestSuite) TestCreateNote() { ctx := createTestContext(receivingAccount, requestingAccount) create := suite.testActivities["dm_for_zork"].Activity + objProp := create.GetActivityStreamsObject() + note := objProp.At(0).GetType().(ap.Statusable) err := suite.federatingDB.Create(ctx, create) suite.NoError(err) @@ -47,18 +49,7 @@ func (suite *CreateTestSuite) TestCreateNote() { msg := <-suite.fromFederator suite.Equal(ap.ObjectNote, msg.APObjectType) suite.Equal(ap.ActivityCreate, msg.APActivityType) - - // shiny new status should be defined on the message - suite.NotNil(msg.GTSModel) - status := msg.GTSModel.(*gtsmodel.Status) - - // status should have some expected values - suite.Equal(requestingAccount.ID, status.AccountID) - suite.Equal("@the_mighty_zork@localhost:8080 hey zork here's a new private note for you", status.Content) - - // status should be in the database - _, err = suite.db.GetStatusByID(context.Background(), status.ID) - suite.NoError(err) + suite.Equal(note, msg.APObjectModel) } func (suite *CreateTestSuite) TestCreateNoteForward() { @@ -78,7 +69,7 @@ func (suite *CreateTestSuite) TestCreateNoteForward() { suite.Equal(ap.ActivityCreate, msg.APActivityType) // nothing should be set as the model since this is a forward - suite.Nil(msg.GTSModel) + suite.Nil(msg.APObjectModel) // but we should have a uri set suite.Equal("http://example.org/users/Some_User/statuses/afaba698-5740-4e32-a702-af61aa543bc1", msg.APIri.String()) diff --git a/internal/federation/federatingdb/db.go b/internal/federation/federatingdb/db.go index 8e98dc2f2..75ef3a2a7 100644 --- a/internal/federation/federatingdb/db.go +++ b/internal/federation/federatingdb/db.go @@ -24,6 +24,7 @@ import ( "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" ) // DB wraps the pub.Database interface with a couple of custom functions for GoToSocial. @@ -33,7 +34,6 @@ type DB interface { Accept(ctx context.Context, accept vocab.ActivityStreamsAccept) error Reject(ctx context.Context, reject vocab.ActivityStreamsReject) error Announce(ctx context.Context, announce vocab.ActivityStreamsAnnounce) error - Question(ctx context.Context, question vocab.ActivityStreamsQuestion) error } // FederatingDB uses the underlying DB interface to implement the go-fed pub.Database interface. @@ -41,13 +41,15 @@ type DB interface { type federatingDB struct { state *state.State converter *typeutils.Converter + filter *visibility.Filter } // New returns a DB interface using the given database and config -func New(state *state.State, converter *typeutils.Converter) DB { +func New(state *state.State, converter *typeutils.Converter, filter *visibility.Filter) DB { fdb := federatingDB{ state: state, converter: converter, + filter: filter, } return &fdb } diff --git a/internal/federation/federatingdb/question.go b/internal/federation/federatingdb/question.go deleted file mode 100644 index 85226d9ed..000000000 --- a/internal/federation/federatingdb/question.go +++ /dev/null @@ -1,32 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package federatingdb - -import ( - "context" - - "github.com/superseriousbusiness/activity/streams/vocab" -) - -func (f *federatingDB) Question(ctx context.Context, question vocab.ActivityStreamsQuestion) error { - receivingAccount, requestingAccount, internal := extractFromCtx(ctx) - if internal { - return nil // Already processed. - } - return f.createStatusable(ctx, question, receivingAccount, requestingAccount) -} diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go index 5d3d4a0ff..26ea81f72 100644 --- a/internal/federation/federatingdb/update.go +++ b/internal/federation/federatingdb/update.go @@ -19,11 +19,13 @@ package federatingdb import ( "context" + "errors" "codeberg.org/gruf/go-logger/v2/level" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -71,18 +73,21 @@ func (f *federatingDB) updateAccountable(ctx context.Context, receivingAcct *gts // Extract AP URI of the updated Accountable model. idProp := accountable.GetJSONLDId() if idProp == nil || !idProp.IsIRI() { - return gtserror.New("Accountable id prop was nil or not IRI") + return gtserror.New("invalid id prop") } - updatedAcctURI := idProp.GetIRI() - // Don't try to update local accounts, it will break things. - if updatedAcctURI.Host == config.GetHost() { + // Get the account URI string for checks + accountURI := idProp.GetIRI() + accountURIStr := accountURI.String() + + // Don't try to update local accounts. + if accountURI.Host == config.GetHost() { return nil } - // Ensure Accountable and requesting account are one and the same. - if updatedAcctURIStr := updatedAcctURI.String(); requestingAcct.URI != updatedAcctURIStr { - return gtserror.Newf("update for %s was requested by %s, this is not valid", updatedAcctURIStr, requestingAcct.URI) + // Check that update was by the account themselves. + if accountURIStr != requestingAcct.URI { + return gtserror.Newf("update for %s was not requested by owner", accountURIStr) } // Pass in to the processor the existing version of the requesting @@ -117,15 +122,31 @@ func (f *federatingDB) updateStatusable(ctx context.Context, receivingAcct *gtsm return nil } + // Check if this is a forwarded object, i.e. did + // the account making the request also create this? + forwarded := !isSender(statusable, requestingAcct) + // Get the status we have on file for this URI string. status, err := f.state.DB.GetStatusByURI(ctx, statusURIStr) - if err != nil { + if err != nil && !errors.Is(err, db.ErrNoEntries) { return gtserror.Newf("error fetching status from db: %w", err) } - // Check that update was by the status author. - if status.AccountID != requestingAcct.ID { - return gtserror.Newf("update for %s was not requested by author", statusURIStr) + if status == nil { + // We haven't seen this status before, be + // lenient and handle as a CREATE event. + return f.createStatusable(ctx, + receivingAcct, + requestingAcct, + statusable, + forwarded, + ) + } + + if forwarded { + // For forwarded updates, set a nil AS + // status to force refresh from remote. + statusable = nil } // Queue an UPDATE NOTE activity to our fedi API worker, @@ -134,7 +155,7 @@ func (f *federatingDB) updateStatusable(ctx context.Context, receivingAcct *gtsm APObjectType: ap.ObjectNote, APActivityType: ap.ActivityUpdate, GTSModel: status, // original status - APObjectModel: statusable, + APObjectModel: (ap.Statusable)(statusable), ReceivingAccount: receivingAcct, }) diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go index d46451e21..dd7a2240e 100644 --- a/internal/federation/federatingdb/util.go +++ b/internal/federation/federatingdb/util.go @@ -20,7 +20,6 @@ package federatingdb import ( "context" "encoding/json" - "errors" "fmt" "net/url" @@ -37,6 +36,30 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/uris" ) +func typeNames(objects []ap.TypeOrIRI) []string { + names := make([]string, len(objects)) + for i, object := range objects { + if object.IsIRI() { + names[i] = "IRI" + } else if t := object.GetType(); t != nil { + names[i] = t.GetTypeName() + } else { + names[i] = "nil" + } + } + return names +} + +// isSender returns whether an object with AttributedTo property comes from the given requesting account. +func isSender(with ap.WithAttributedTo, requester *gtsmodel.Account) bool { + for _, uri := range ap.GetAttributedTo(with) { + if uri.String() == requester.URI { + return true + } + } + return false +} + func sameActor(actor1 vocab.ActivityStreamsActorProperty, actor2 vocab.ActivityStreamsActorProperty) bool { if actor1 == nil || actor2 == nil { return false @@ -78,131 +101,31 @@ func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (idURL *url.URL, l.Debug("entering NewID") } - switch t.GetTypeName() { - case ap.ActivityFollow: - // FOLLOW - // ID might already be set on a follow we've created, so check it here and return it if it is - follow, ok := t.(vocab.ActivityStreamsFollow) - if !ok { - return nil, errors.New("newid: follow couldn't be parsed into vocab.ActivityStreamsFollow") - } - idProp := follow.GetJSONLDId() - if idProp != nil { - if idProp.IsIRI() { - return idProp.GetIRI(), nil - } - } - // it's not set so create one based on the actor set on the follow (ie., the followER not the followEE) - actorProp := follow.GetActivityStreamsActor() - if actorProp != nil { - for iter := actorProp.Begin(); iter != actorProp.End(); iter = iter.Next() { - // take the IRI of the first actor we can find (there should only be one) - if iter.IsIRI() { - // if there's an error here, just use the fallback behavior -- we don't need to return an error here - if actorAccount, err := f.state.DB.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil { - newID, err := id.NewRandomULID() - if err != nil { - return nil, err - } - return url.Parse(uris.GenerateURIForFollow(actorAccount.Username, newID)) - } + // Most of our types set an ID already + // by this point, return this if found. + idProp := t.GetJSONLDId() + if idProp != nil && idProp.IsIRI() { + return idProp.GetIRI(), nil + } + + if t.GetTypeName() == ap.ActivityFollow { + follow, _ := t.(vocab.ActivityStreamsFollow) + + // If an actor URI has been set, create a new ID + // based on actor (i.e. followER not the followEE). + if uri := ap.GetActor(follow); len(uri) == 1 { + if actorAccount, err := f.state.DB.GetAccountByURI(ctx, uri[0].String()); err == nil { + newID, err := id.NewRandomULID() + if err != nil { + return nil, err } - } - } - case ap.ObjectNote: - // NOTE aka STATUS - // ID might already be set on a note we've created, so check it here and return it if it is - note, ok := t.(vocab.ActivityStreamsNote) - if !ok { - return nil, errors.New("newid: note couldn't be parsed into vocab.ActivityStreamsNote") - } - idProp := note.GetJSONLDId() - if idProp != nil { - if idProp.IsIRI() { - return idProp.GetIRI(), nil - } - } - case ap.ActivityLike: - // LIKE aka FAVE - // ID might already be set on a fave we've created, so check it here and return it if it is - fave, ok := t.(vocab.ActivityStreamsLike) - if !ok { - return nil, errors.New("newid: fave couldn't be parsed into vocab.ActivityStreamsLike") - } - idProp := fave.GetJSONLDId() - if idProp != nil { - if idProp.IsIRI() { - return idProp.GetIRI(), nil - } - } - case ap.ActivityCreate: - // CREATE - // ID might already be set on a Create, so check it here and return it if it is - create, ok := t.(vocab.ActivityStreamsCreate) - if !ok { - return nil, errors.New("newid: create couldn't be parsed into vocab.ActivityStreamsCreate") - } - idProp := create.GetJSONLDId() - if idProp != nil { - if idProp.IsIRI() { - return idProp.GetIRI(), nil - } - } - case ap.ActivityAnnounce: - // ANNOUNCE aka BOOST - // ID might already be set on an announce we've created, so check it here and return it if it is - announce, ok := t.(vocab.ActivityStreamsAnnounce) - if !ok { - return nil, errors.New("newid: announce couldn't be parsed into vocab.ActivityStreamsAnnounce") - } - idProp := announce.GetJSONLDId() - if idProp != nil { - if idProp.IsIRI() { - return idProp.GetIRI(), nil - } - } - case ap.ActivityUpdate: - // UPDATE - // ID might already be set on an update we've created, so check it here and return it if it is - update, ok := t.(vocab.ActivityStreamsUpdate) - if !ok { - return nil, errors.New("newid: update couldn't be parsed into vocab.ActivityStreamsUpdate") - } - idProp := update.GetJSONLDId() - if idProp != nil { - if idProp.IsIRI() { - return idProp.GetIRI(), nil - } - } - case ap.ActivityBlock: - // BLOCK - // ID might already be set on a block we've created, so check it here and return it if it is - block, ok := t.(vocab.ActivityStreamsBlock) - if !ok { - return nil, errors.New("newid: block couldn't be parsed into vocab.ActivityStreamsBlock") - } - idProp := block.GetJSONLDId() - if idProp != nil { - if idProp.IsIRI() { - return idProp.GetIRI(), nil - } - } - case ap.ActivityUndo: - // UNDO - // ID might already be set on an undo we've created, so check it here and return it if it is - undo, ok := t.(vocab.ActivityStreamsUndo) - if !ok { - return nil, errors.New("newid: undo couldn't be parsed into vocab.ActivityStreamsUndo") - } - idProp := undo.GetJSONLDId() - if idProp != nil { - if idProp.IsIRI() { - return idProp.GetIRI(), nil + return url.Parse(uris.GenerateURIForFollow(actorAccount.Username, newID)) } } } - // fallback default behavior: just return a random ULID after our protocol and host + // Default fallback behaviour: + // {proto}://{host}/{randomID} newID, err := id.NewRandomULID() if err != nil { return nil, err diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go index 28dc145af..5a913dbbe 100644 --- a/internal/federation/federatingprotocol.go +++ b/internal/federation/federatingprotocol.go @@ -522,9 +522,6 @@ func (f *Federator) FederatingCallbacks(ctx context.Context) (wrapped pub.Federa func(ctx context.Context, announce vocab.ActivityStreamsAnnounce) error { return f.FederatingDB().Announce(ctx, announce) }, - func(ctx context.Context, question vocab.ActivityStreamsQuestion) error { - return f.FederatingDB().Question(ctx, question) - }, } return diff --git a/internal/processing/search/get.go b/internal/processing/search/get.go index 30a2745af..4c09f05bb 100644 --- a/internal/processing/search/get.go +++ b/internal/processing/search/get.go @@ -603,7 +603,6 @@ func (p *Processor) statusByURI( requestingAccount.Username, uri, ) - return status, err } diff --git a/internal/processing/workers/fromfediapi.go b/internal/processing/workers/fromfediapi.go index f57235bf1..1ce3b6076 100644 --- a/internal/processing/workers/fromfediapi.go +++ b/internal/processing/workers/fromfediapi.go @@ -32,6 +32,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing/account" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" ) // fediAPI wraps processing functions @@ -142,27 +143,22 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg messages.FromFe } } - return nil + return gtserror.Newf("unhandled: %s %s", fMsg.APActivityType, fMsg.APObjectType) } func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) error { var ( status *gtsmodel.Status err error - - // Check the federatorMsg for either an already dereferenced - // and converted status pinned to the message, or a forwarded - // AP IRI that we still need to deref. - forwarded = (fMsg.GTSModel == nil) ) - if forwarded { - // Model was not set, deref with IRI. + if fMsg.APObjectModel == nil /* i.e. forwarded */ { + // Model was not set, deref with IRI (this is a forward). // This will also cause the status to be inserted into the db. status, err = p.statusFromAPIRI(ctx, fMsg) } else { // Model is set, ensure we have the most up-to-date model. - status, err = p.statusFromGTSModel(ctx, fMsg) + status, err = p.statusFromAPModel(ctx, fMsg) } if err != nil { @@ -188,19 +184,10 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e } } - // Ensure status ancestors dereferenced. We need at least the - // immediate parent (if present) to ascertain timelineability. - if err := p.federate.DereferenceStatusAncestors( - ctx, - fMsg.ReceivingAccount.Username, - status, - ); err != nil { - return err - } - if status.InReplyToID != "" { - // Interaction counts changed on the replied status; - // uncache the prepared version from all timelines. + // Interaction counts changed on the replied status; uncache the + // prepared version from all timelines. The status dereferencer + // functions will ensure necessary ancestors exist before this point. p.surface.invalidateStatusFromTimelines(ctx, status.InReplyToID) } @@ -211,23 +198,31 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e return nil } -func (p *fediAPI) statusFromGTSModel(ctx context.Context, fMsg messages.FromFediAPI) (*gtsmodel.Status, error) { - // There should be a status pinned to the message: - // we've already checked to ensure this is not nil. - status, ok := fMsg.GTSModel.(*gtsmodel.Status) +func (p *fediAPI) statusFromAPModel(ctx context.Context, fMsg messages.FromFediAPI) (*gtsmodel.Status, error) { + // AP statusable representation MUST have been set. + statusable, ok := fMsg.APObjectModel.(ap.Statusable) if !ok { - err := gtserror.New("Note was not parseable as *gtsmodel.Status") - return nil, err + return nil, gtserror.Newf("cannot cast %T -> ap.Statusable", fMsg.APObjectModel) } - // AP statusable representation may have also - // been set on message (no problem if not). - statusable, _ := fMsg.APObjectModel.(ap.Statusable) + // Status may have been set (no problem if not). + status, _ := fMsg.GTSModel.(*gtsmodel.Status) - // Call refresh on status to update - // it (deref remote) if necessary. - var err error - status, _, err = p.federate.RefreshStatus( + if status == nil { + // No status was set, create a bare-bones + // model for the deferencer to flesh-out, + // this indicates it is a new (to us) status. + status = >smodel.Status{ + + // if coming in here status will ALWAYS be remote. + Local: util.Ptr(false), + URI: ap.GetJSONLDId(statusable).String(), + } + } + + // Call refresh on status to either update existing + // model, or parse + insert status from statusable data. + status, _, err := p.federate.RefreshStatus( ctx, fMsg.ReceivingAccount.Username, status, @@ -235,7 +230,7 @@ func (p *fediAPI) statusFromGTSModel(ctx context.Context, fMsg messages.FromFedi false, // Don't force refresh. ) if err != nil { - return nil, gtserror.Newf("%w", err) + return nil, gtserror.Newf("error refreshing status: %w", err) } return status, nil @@ -245,11 +240,8 @@ func (p *fediAPI) statusFromAPIRI(ctx context.Context, fMsg messages.FromFediAPI // There should be a status IRI pinned to // the federatorMsg for us to dereference. if fMsg.APIri == nil { - err := gtserror.New( - "status was not pinned to federatorMsg, " + - "and neither was an IRI for us to dereference", - ) - return nil, err + const text = "neither APObjectModel nor APIri set" + return nil, gtserror.New(text) } // Get the status + ensure we have @@ -260,7 +252,7 @@ func (p *fediAPI) statusFromAPIRI(ctx context.Context, fMsg messages.FromFediAPI fMsg.APIri, ) if err != nil { - return nil, gtserror.Newf("%w", err) + return nil, gtserror.Newf("error getting status by uri %s: %w", fMsg.APIri, err) } return status, nil @@ -337,7 +329,9 @@ func (p *fediAPI) CreateAnnounce(ctx context.Context, fMsg messages.FromFediAPI) return gtserror.Newf("%T not parseable as *gtsmodel.Status", fMsg.GTSModel) } - // Dereference status that this status boosts. + // Dereference status that this boosts, note + // that this will handle dereferencing the status + // ancestors / descendants where appropriate. if err := p.federate.DereferenceAnnounce( ctx, status, @@ -358,15 +352,6 @@ func (p *fediAPI) CreateAnnounce(ctx context.Context, fMsg messages.FromFediAPI) return gtserror.Newf("db error inserting status: %w", err) } - // Ensure boosted status ancestors dereferenced. We need at least - // the immediate parent (if present) to ascertain timelineability. - if err := p.federate.DereferenceStatusAncestors(ctx, - fMsg.ReceivingAccount.Username, - status.BoostOf, - ); err != nil { - return err - } - // Timeline and notify the announce. if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { return gtserror.Newf("error timelining status: %w", err) @@ -526,23 +511,25 @@ func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg messages.FromFediAPI) e } // Cast the updated ActivityPub statusable object . - apStatus, ok := fMsg.APObjectModel.(ap.Statusable) - if !ok { - return gtserror.Newf("cannot cast %T -> ap.Statusable", fMsg.APObjectModel) - } + apStatus, _ := fMsg.APObjectModel.(ap.Statusable) // Fetch up-to-date attach status attachments, etc. - _, _, err := p.federate.RefreshStatus( + _, statusable, err := p.federate.RefreshStatus( ctx, fMsg.ReceivingAccount.Username, existing, apStatus, - false, + true, ) if err != nil { return gtserror.Newf("error refreshing updated status: %w", err) } + if statusable != nil { + // Status representation was refetched, uncache from timelines. + p.surface.invalidateStatusFromTimelines(ctx, existing.ID) + } + return nil } diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go index f8e3941fc..b8d86ac45 100644 --- a/internal/processing/workers/fromfediapi_test.go +++ b/internal/processing/workers/fromfediapi_test.go @@ -29,7 +29,6 @@ import ( apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/stream" "github.com/superseriousbusiness/gotosocial/internal/util" @@ -92,54 +91,33 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() { repliedStatus := suite.testStatuses["local_account_1_status_1"] replyingAccount := suite.testAccounts["remote_account_1"] - replyingStatus := >smodel.Status{ - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - URI: "http://fossbros-anonymous.io/users/foss_satan/statuses/106221634728637552", - URL: "http://fossbros-anonymous.io/@foss_satan/106221634728637552", - Content: `

@the_mighty_zork nice there it is:

social.pixie.town/users/f0x/st

`, - Mentions: []*gtsmodel.Mention{ - { - TargetAccountURI: repliedAccount.URI, - NameString: "@the_mighty_zork@localhost:8080", - }, - }, - AccountID: replyingAccount.ID, - AccountURI: replyingAccount.URI, - InReplyToID: repliedStatus.ID, - InReplyToURI: repliedStatus.URI, - InReplyToAccountID: repliedAccount.ID, - Visibility: gtsmodel.VisibilityUnlocked, - ActivityStreamsType: ap.ObjectNote, - Federated: util.Ptr(true), - Boostable: util.Ptr(true), - Replyable: util.Ptr(true), - Likeable: util.Ptr(false), - } + // Set the replyingAccount's last fetched_at + // date to something recent so no refresh is attempted. + replyingAccount.FetchedAt = time.Now() + err := suite.state.DB.UpdateAccount(context.Background(), replyingAccount, "fetched_at") + suite.NoError(err) + // Get replying statusable to use from remote test statuses. + const replyingURI = "http://fossbros-anonymous.io/users/foss_satan/statuses/106221634728637552" + replyingStatusable := testrig.NewTestFediStatuses()[replyingURI] + ap.AppendInReplyTo(replyingStatusable, testrig.URLMustParse(repliedStatus.URI)) + + // Open a websocket stream to later test the streamed status reply. wssStream, errWithCode := suite.processor.Stream().Open(context.Background(), repliedAccount, stream.TimelineHome) suite.NoError(errWithCode) - // id the status based on the time it was created - statusID, err := id.NewULIDFromTime(replyingStatus.CreatedAt) - suite.NoError(err) - replyingStatus.ID = statusID - - err = suite.db.PutStatus(context.Background(), replyingStatus) - suite.NoError(err) - + // Send the replied status off to the fedi worker to be further processed. err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, - GTSModel: replyingStatus, + APObjectModel: replyingStatusable, ReceivingAccount: suite.testAccounts["local_account_1"], }) suite.NoError(err) // side effects should be triggered // 1. status should be in the database - suite.NotEmpty(replyingStatus.ID) - _, err = suite.db.GetStatusByID(context.Background(), replyingStatus.ID) + replyingStatus, err := suite.state.DB.GetStatusByURI(context.Background(), replyingURI) suite.NoError(err) // 2. a notification should exist for the mention diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go new file mode 100644 index 000000000..b4cbcf5f3 --- /dev/null +++ b/internal/scheduler/scheduler.go @@ -0,0 +1,131 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package scheduler + +import ( + "context" + "sync" + "time" + + "codeberg.org/gruf/go-runners" + "codeberg.org/gruf/go-sched" +) + +// Scheduler wraps an underlying task scheduler +// to provide concurrency safe tracking by 'id' +// strings in order to provide easy cancellation. +type Scheduler struct { + sch sched.Scheduler + ts map[string]*task + mu sync.Mutex +} + +// Start will start the Scheduler background routine, returning success. +// Note that this creates a new internal task map, stopping and dropping +// all previously known running tasks. +func (sch *Scheduler) Start() bool { + if sch.sch.Start(nil) { + sch.ts = make(map[string]*task) + return true + } + return false +} + +// Stop will stop the Scheduler background routine, returning success. +// Note that this nils-out the internal task map, stopping and dropping +// all previously known running tasks. +func (sch *Scheduler) Stop() bool { + if sch.sch.Stop() { + sch.ts = nil + return true + } + return false +} + +// AddOnce adds a run-once job with given id, function and timing parameters, returning success. +func (sch *Scheduler) AddOnce(id string, start time.Time, fn func(context.Context, time.Time)) bool { + return sch.schedule(id, fn, (*sched.Once)(&start)) +} + +// AddRecurring adds a new recurring job with given id, function and timing parameters, returning success. +func (sch *Scheduler) AddRecurring(id string, start time.Time, freq time.Duration, fn func(context.Context, time.Time)) bool { + return sch.schedule(id, fn, &sched.PeriodicAt{Once: sched.Once(start), Period: sched.Periodic(freq)}) +} + +// Cancel will attempt to cancel job with given id, +// dropping it from internal scheduler and task map. +func (sch *Scheduler) Cancel(id string) bool { + // Attempt to acquire and + // delete task with iD. + sch.mu.Lock() + task, ok := sch.ts[id] + delete(sch.ts, id) + sch.mu.Unlock() + + if !ok { + // none found. + return false + } + + // Cancel the queued + // job from Scheduler. + task.cncl() + return true +} + +func (sch *Scheduler) schedule(id string, fn func(context.Context, time.Time), t sched.Timing) bool { + if fn == nil { + panic("nil function") + } + + // Perform within lock. + sch.mu.Lock() + defer sch.mu.Unlock() + + if _, ok := sch.ts[id]; ok { + // existing task already + // exists under this ID. + return false + } + + // Extract current sched context. + doneCh := sch.sch.Done() + ctx := runners.CancelCtx(doneCh) + + // Create a new job to hold task function with + // timing, passing in the current sched context. + job := sched.NewJob(func(now time.Time) { + fn(ctx, now) + }) + job.With(t) + + // Queue job with the scheduler, + // and store a new encompassing task. + cncl := sch.sch.Schedule(job) + sch.ts[id] = &task{ + job: job, + cncl: cncl, + } + + return true +} + +type task struct { + job *sched.Job + cncl func() +} diff --git a/internal/workers/workers.go b/internal/workers/workers.go index 8f884d427..3617ce333 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -23,13 +23,13 @@ import ( "runtime" "codeberg.org/gruf/go-runners" - "codeberg.org/gruf/go-sched" "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/scheduler" ) type Workers struct { // Main task scheduler instance. - Scheduler sched.Scheduler + Scheduler scheduler.Scheduler // ClientAPI provides a worker pool that handles both // incoming client actions, and our own side-effects. @@ -70,9 +70,7 @@ func (w *Workers) Start() { // Get currently set GOMAXPROCS. maxprocs := runtime.GOMAXPROCS(0) - tryUntil("starting scheduler", 5, func() bool { - return w.Scheduler.Start(nil) - }) + tryUntil("starting scheduler", 5, w.Scheduler.Start) tryUntil("starting client API workerpool", 5, func() bool { return w.ClientAPI.Start(4*maxprocs, 400*maxprocs) diff --git a/testrig/federatingdb.go b/testrig/federatingdb.go index a1215a7ba..d66a82306 100644 --- a/testrig/federatingdb.go +++ b/testrig/federatingdb.go @@ -21,9 +21,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" ) // NewTestFederatingDB returns a federating DB with the underlying db func NewTestFederatingDB(state *state.State) federatingdb.DB { - return federatingdb.New(state, typeutils.NewConverter(state)) + return federatingdb.New(state, typeutils.NewConverter(state), visibility.NewFilter(state)) } diff --git a/testrig/util.go b/testrig/util.go index 9512e3b6b..8ffc1b8ea 100644 --- a/testrig/util.go +++ b/testrig/util.go @@ -41,7 +41,7 @@ func StartWorkers(state *state.State) { state.Workers.ProcessFromClientAPI = func(context.Context, messages.FromClientAPI) error { return nil } state.Workers.ProcessFromFediAPI = func(context.Context, messages.FromFediAPI) error { return nil } - _ = state.Workers.Scheduler.Start(nil) + _ = state.Workers.Scheduler.Start() _ = state.Workers.ClientAPI.Start(1, 10) _ = state.Workers.Federator.Start(1, 10) _ = state.Workers.Media.Start(1, 10)