diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go
index d43599b05..bc56e21f0 100644
--- a/cmd/gotosocial/action/server/server.go
+++ b/cmd/gotosocial/action/server/server.go
@@ -35,7 +35,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"go.uber.org/automaxprocs/maxprocs"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -45,7 +44,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing"
@@ -107,19 +105,11 @@ var Start action.GTSAction = func(ctx context.Context) error {
state.Workers.Start()
defer state.Workers.Stop()
- // Create the client API and federator worker pools
- // NOTE: these MUST NOT be used until they are passed to the
- // processor and it is started. The reason being that the processor
- // sets the Worker process functions and start the underlying pools
- // TODO: move these into state.Workers (and maybe reformat worker pools).
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
// build backend handlers
mediaManager := media.NewManager(&state)
oauthServer := oauth.New(ctx, dbService)
typeConverter := typeutils.NewConverter(dbService)
- federatingDB := federatingdb.New(dbService, fedWorker, typeConverter)
+ federatingDB := federatingdb.New(&state, typeConverter)
transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client)
federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager)
@@ -140,11 +130,15 @@ var Start action.GTSAction = func(ctx context.Context) error {
}
// create the message processor using the other services we've created so far
- processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, storage, dbService, emailSender, clientWorker, fedWorker)
+ processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender)
if err := processor.Start(); err != nil {
return fmt.Errorf("error creating processor: %s", err)
}
+ // Set state client / federator worker enqueue functions
+ state.Workers.EnqueueClientAPI = processor.EnqueueClientAPI
+ state.Workers.EnqueueFederator = processor.EnqueueFederator
+
/*
HTTP router initialization
*/
diff --git a/cmd/gotosocial/action/testrig/testrig.go b/cmd/gotosocial/action/testrig/testrig.go
index 3be7907fe..68bb94ec3 100644
--- a/cmd/gotosocial/action/testrig/testrig.go
+++ b/cmd/gotosocial/action/testrig/testrig.go
@@ -33,14 +33,13 @@ import (
"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
"github.com/superseriousbusiness/gotosocial/internal/api"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/web"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -48,37 +47,44 @@ import (
// Start creates and starts a gotosocial testrig server
var Start action.GTSAction = func(ctx context.Context) error {
+ var state state.State
+
testrig.InitTestConfig()
testrig.InitTestLog()
- dbService := testrig.NewTestDB()
- testrig.StandardDBSetup(dbService, nil)
- var storageBackend *storage.Driver
- if os.Getenv("GTS_STORAGE_BACKEND") == "s3" {
- storageBackend, _ = storage.NewS3Storage()
- } else {
- storageBackend = testrig.NewInMemoryStorage()
- }
- testrig.StandardStorageSetup(storageBackend, "./testrig/media")
+ // Initialize caches
+ state.Caches.Init()
+ state.Caches.Start()
+ defer state.Caches.Stop()
- // Create client API and federator worker pools
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ state.DB = testrig.NewTestDB(&state)
+ testrig.StandardDBSetup(state.DB, nil)
+
+ if os.Getenv("GTS_STORAGE_BACKEND") == "s3" {
+ state.Storage, _ = storage.NewS3Storage()
+ } else {
+ state.Storage = testrig.NewInMemoryStorage()
+ }
+ testrig.StandardStorageSetup(state.Storage, "./testrig/media")
+
+ // Initialize workers.
+ state.Workers.Start()
+ defer state.Workers.Stop()
// build backend handlers
- transportController := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
+ transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
r := io.NopCloser(bytes.NewReader([]byte{}))
return &http.Response{
StatusCode: 200,
Body: r,
}, nil
- }, ""), dbService, fedWorker)
- mediaManager := testrig.NewTestMediaManager(dbService, storageBackend)
- federator := testrig.NewTestFederator(dbService, transportController, storageBackend, mediaManager, fedWorker)
+ }, ""))
+ mediaManager := testrig.NewTestMediaManager(&state)
+ federator := testrig.NewTestFederator(&state, transportController, mediaManager)
emailSender := testrig.NewEmailSender("./web/template/", nil)
- processor := testrig.NewTestProcessor(dbService, storageBackend, federator, emailSender, mediaManager, clientWorker, fedWorker)
+ processor := testrig.NewTestProcessor(&state, federator, emailSender, mediaManager)
if err := processor.Start(); err != nil {
return fmt.Errorf("error starting processor: %s", err)
}
@@ -87,7 +93,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
HTTP router initialization
*/
- router := testrig.NewTestRouter(dbService)
+ router := testrig.NewTestRouter(state.DB)
// attach global middlewares which are used for every request
router.AttachGlobalMiddleware(
@@ -112,7 +118,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
}
}
- routerSession, err := dbService.GetSession(ctx)
+ routerSession, err := state.DB.GetSession(ctx)
if err != nil {
return fmt.Errorf("error retrieving router session for session middleware: %w", err)
}
@@ -123,13 +129,13 @@ var Start action.GTSAction = func(ctx context.Context) error {
}
var (
- authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths
- clientModule = api.NewClient(dbService, processor) // api client endpoints
- fileserverModule = api.NewFileserver(processor) // fileserver endpoints
- wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints
- nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint
- activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints
- webModule = web.New(dbService, processor) // web pages + user profiles + settings panels etc
+ authModule = api.NewAuth(state.DB, processor, idp, routerSession, sessionName) // auth/oauth paths
+ clientModule = api.NewClient(state.DB, processor) // api client endpoints
+ fileserverModule = api.NewFileserver(processor) // fileserver endpoints
+ wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints
+ nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint
+ activityPubModule = api.NewActivityPub(state.DB, processor) // ActivityPub endpoints
+ webModule = web.New(state.DB, processor) // web pages + user profiles + settings panels etc
)
// these should be routed in order
@@ -142,7 +148,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
activityPubModule.RoutePublicKey(router)
webModule.Route(router)
- gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager)
+ gts, err := gotosocial.NewServer(state.DB, router, federator, mediaManager)
if err != nil {
return fmt.Errorf("error creating gotosocial service: %s", err)
}
@@ -157,8 +163,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
sig := <-sigs
log.Infof(ctx, "received signal %s, shutting down", sig)
- testrig.StandardDBTeardown(dbService)
- testrig.StandardStorageTeardown(storageBackend)
+ testrig.StandardDBTeardown(state.DB)
+ testrig.StandardStorageTeardown(state.Storage)
// close down all running services in order
if err := gts.Stop(ctx); err != nil {
diff --git a/internal/api/activitypub/emoji/emojiget_test.go b/internal/api/activitypub/emoji/emojiget_test.go
index cd7333955..8f99efdfc 100644
--- a/internal/api/activitypub/emoji/emojiget_test.go
+++ b/internal/api/activitypub/emoji/emojiget_test.go
@@ -27,15 +27,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/emoji"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -50,6 +49,7 @@ type EmojiGetTestSuite struct {
emailSender email.Sender
processor *processing.Processor
storage *storage.Driver
+ state state.State
testEmojis map[string]*gtsmodel.Emoji
testAccounts map[string]*gtsmodel.Account
@@ -65,19 +65,23 @@ func (suite *EmojiGetTestSuite) SetupSuite() {
}
func (suite *EmojiGetTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- suite.db = testrig.NewTestDB()
- suite.tc = testrig.NewTestTypeConverter(suite.db)
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+
+ suite.tc = testrig.NewTestTypeConverter(suite.db)
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.emojiModule = emoji.New(suite.processor)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -90,6 +94,7 @@ func (suite *EmojiGetTestSuite) SetupTest() {
func (suite *EmojiGetTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func (suite *EmojiGetTestSuite) TestGetEmoji() {
diff --git a/internal/api/activitypub/users/inboxpost_test.go b/internal/api/activitypub/users/inboxpost_test.go
index 0ad63abf7..fa23204c9 100644
--- a/internal/api/activitypub/users/inboxpost_test.go
+++ b/internal/api/activitypub/users/inboxpost_test.go
@@ -34,11 +34,9 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -86,13 +84,10 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
- federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
+ federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
- processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
+ processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
@@ -190,13 +185,10 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
- federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
+ federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
- processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
+ processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
@@ -291,9 +283,6 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
// use a different version of the mock http client which serves the updated
// version of the remote account, as though it had been updated there too;
// this is needed so it can be dereferenced + updated properly
@@ -301,10 +290,11 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
mockHTTPClient.TestRemotePeople = map[string]vocab.ActivityStreamsPerson{
updatedAccount.URI: asAccount,
}
- tc := testrig.NewTestTransportController(mockHTTPClient, suite.db, fedWorker)
- federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
+
+ tc := testrig.NewTestTransportController(&suite.state, mockHTTPClient)
+ federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
- processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
+ processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
@@ -430,15 +420,12 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
- federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
+ federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
- processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
- suite.NoError(processor.Start())
+ processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
+ suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
diff --git a/internal/api/activitypub/users/outboxget_test.go b/internal/api/activitypub/users/outboxget_test.go
index 6e5c4e1e0..8f3306a25 100644
--- a/internal/api/activitypub/users/outboxget_test.go
+++ b/internal/api/activitypub/users/outboxget_test.go
@@ -32,8 +32,6 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -104,13 +102,10 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
- federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
+ federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
- processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
+ processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
@@ -182,13 +177,10 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
- federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
+ federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
- processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
+ processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
diff --git a/internal/api/activitypub/users/repliesget_test.go b/internal/api/activitypub/users/repliesget_test.go
index 4e985a0a1..92e5cddfa 100644
--- a/internal/api/activitypub/users/repliesget_test.go
+++ b/internal/api/activitypub/users/repliesget_test.go
@@ -33,8 +33,6 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -104,13 +102,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
- federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
+ federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
- processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
+ processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
@@ -172,13 +167,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
- federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
+ federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
- processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker)
+ processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
diff --git a/internal/api/activitypub/users/user_test.go b/internal/api/activitypub/users/user_test.go
index 0124925b9..d025eada0 100644
--- a/internal/api/activitypub/users/user_test.go
+++ b/internal/api/activitypub/users/user_test.go
@@ -22,15 +22,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -46,6 +45,7 @@ type UserStandardTestSuite struct {
emailSender email.Sender
processor *processing.Processor
storage *storage.Driver
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -75,19 +75,21 @@ func (suite *UserStandardTestSuite) SetupSuite() {
}
func (suite *UserStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.userModule = users.New(suite.processor)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -100,4 +102,5 @@ func (suite *UserStandardTestSuite) SetupTest() {
func (suite *UserStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go
index a5e518cda..1a15155bd 100644
--- a/internal/api/auth/auth_test.go
+++ b/internal/api/auth/auth_test.go
@@ -28,17 +28,16 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/auth"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -47,6 +46,7 @@ type AuthStandardTestSuite struct {
suite.Suite
db db.DB
storage *storage.Driver
+ state state.State
mediaManager media.Manager
federator federation.Federator
processor *processing.Processor
@@ -78,18 +78,19 @@ func (suite *AuthStandardTestSuite) SetupSuite() {
}
func (suite *AuthStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.authModule = auth.New(suite.db, suite.processor, suite.idp)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
diff --git a/internal/api/client/accounts/account_test.go b/internal/api/client/accounts/account_test.go
index 5a25c12f1..ab3f4cd1f 100644
--- a/internal/api/client/accounts/account_test.go
+++ b/internal/api/client/accounts/account_test.go
@@ -27,16 +27,15 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -51,6 +50,7 @@ type AccountStandardTestSuite struct {
processor *processing.Processor
emailSender email.Sender
sentEmails map[string]string
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -76,19 +76,22 @@ func (suite *AccountStandardTestSuite) SetupSuite() {
}
func (suite *AccountStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.accountsModule = accounts.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -99,6 +102,7 @@ func (suite *AccountStandardTestSuite) SetupTest() {
func (suite *AccountStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func (suite *AccountStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context {
diff --git a/internal/api/client/admin/admin_test.go b/internal/api/client/admin/admin_test.go
index 4f3f48904..1d19635f0 100644
--- a/internal/api/client/admin/admin_test.go
+++ b/internal/api/client/admin/admin_test.go
@@ -27,16 +27,15 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -51,6 +50,7 @@ type AdminStandardTestSuite struct {
processor *processing.Processor
emailSender email.Sender
sentEmails map[string]string
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -82,19 +82,22 @@ func (suite *AdminStandardTestSuite) SetupSuite() {
}
func (suite *AdminStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.adminModule = admin.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -103,6 +106,7 @@ func (suite *AdminStandardTestSuite) SetupTest() {
func (suite *AdminStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func (suite *AdminStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context {
diff --git a/internal/api/client/bookmarks/bookmarks_test.go b/internal/api/client/bookmarks/bookmarks_test.go
index c39ad49f3..931d504f7 100644
--- a/internal/api/client/bookmarks/bookmarks_test.go
+++ b/internal/api/client/bookmarks/bookmarks_test.go
@@ -32,16 +32,15 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks"
"github.com/superseriousbusiness/gotosocial/internal/api/client/statuses"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -57,6 +56,7 @@ type BookmarkTestSuite struct {
emailSender email.Sender
processor *processing.Processor
storage *storage.Driver
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -87,22 +87,25 @@ func (suite *BookmarkTestSuite) SetupSuite() {
}
func (suite *BookmarkTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
- suite.tc = testrig.NewTestTypeConverter(suite.db)
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
+ suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.statusModule = statuses.New(suite.processor)
suite.bookmarkModule = bookmarks.New(suite.processor)
@@ -112,6 +115,7 @@ func (suite *BookmarkTestSuite) SetupTest() {
func (suite *BookmarkTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func (suite *BookmarkTestSuite) getBookmarks(
diff --git a/internal/api/client/favourites/favourites_test.go b/internal/api/client/favourites/favourites_test.go
index 7949aa38c..71c7097cc 100644
--- a/internal/api/client/favourites/favourites_test.go
+++ b/internal/api/client/favourites/favourites_test.go
@@ -21,14 +21,13 @@ package favourites_test
import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/favourites"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -44,6 +43,7 @@ type FavouritesStandardTestSuite struct {
emailSender email.Sender
processor *processing.Processor
storage *storage.Driver
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -71,22 +71,25 @@ func (suite *FavouritesStandardTestSuite) SetupSuite() {
}
func (suite *FavouritesStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
- suite.tc = testrig.NewTestTypeConverter(suite.db)
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
+ suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.favModule = favourites.New(suite.processor)
suite.NoError(suite.processor.Start())
@@ -95,6 +98,7 @@ func (suite *FavouritesStandardTestSuite) SetupTest() {
func (suite *FavouritesStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func (suite *FavouritesStandardTestSuite) TestProcessFave() {}
diff --git a/internal/api/client/followrequests/followrequest_test.go b/internal/api/client/followrequests/followrequest_test.go
index 7a08479ab..294dbc7ed 100644
--- a/internal/api/client/followrequests/followrequest_test.go
+++ b/internal/api/client/followrequests/followrequest_test.go
@@ -26,16 +26,15 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -48,6 +47,7 @@ type FollowRequestStandardTestSuite struct {
federator federation.Federator
processor *processing.Processor
emailSender email.Sender
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -73,18 +73,21 @@ func (suite *FollowRequestStandardTestSuite) SetupSuite() {
}
func (suite *FollowRequestStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.followRequestModule = followrequests.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -95,6 +98,7 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
func (suite *FollowRequestStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func (suite *FollowRequestStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context {
diff --git a/internal/api/client/instance/instance_test.go b/internal/api/client/instance/instance_test.go
index ff622febe..6870d2a44 100644
--- a/internal/api/client/instance/instance_test.go
+++ b/internal/api/client/instance/instance_test.go
@@ -26,16 +26,15 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/instance"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -50,6 +49,7 @@ type InstanceStandardTestSuite struct {
processor *processing.Processor
emailSender email.Sender
sentEmails map[string]string
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -75,19 +75,22 @@ func (suite *InstanceStandardTestSuite) SetupSuite() {
}
func (suite *InstanceStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.instanceModule = instance.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -96,6 +99,7 @@ func (suite *InstanceStandardTestSuite) SetupTest() {
func (suite *InstanceStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func (suite *InstanceStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, method string, path string, body []byte, contentType string, auth bool) *gin.Context {
diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go
index caa40b061..6439895f3 100644
--- a/internal/api/client/media/mediacreate_test.go
+++ b/internal/api/client/media/mediacreate_test.go
@@ -33,7 +33,6 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -41,9 +40,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -60,6 +59,7 @@ type MediaCreateTestSuite struct {
oauthServer oauth.Server
emailSender email.Sender
processor *processing.Processor
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -78,21 +78,24 @@ type MediaCreateTestSuite struct {
*/
func (suite *MediaCreateTestSuite) SetupSuite() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
// setup standard items
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
// setup module being tested
suite.mediaModule = mediamodule.New(suite.processor)
@@ -102,11 +105,15 @@ func (suite *MediaCreateTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
log.Panicf(nil, "error closing db connection: %s", err)
}
+ testrig.StopWorkers(&suite.state)
}
func (suite *MediaCreateTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
+
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go
index cb96e8aa1..75657e1b5 100644
--- a/internal/api/client/media/mediaupdate_test.go
+++ b/internal/api/client/media/mediaupdate_test.go
@@ -31,7 +31,6 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -39,9 +38,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -58,6 +57,7 @@ type MediaUpdateTestSuite struct {
oauthServer oauth.Server
emailSender email.Sender
processor *processing.Processor
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -76,21 +76,23 @@ type MediaUpdateTestSuite struct {
*/
func (suite *MediaUpdateTestSuite) SetupSuite() {
+ testrig.StartWorkers(&suite.state)
+
// setup standard items
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
// setup module being tested
suite.mediaModule = mediamodule.New(suite.processor)
@@ -100,11 +102,15 @@ func (suite *MediaUpdateTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
log.Panicf(nil, "error closing db connection: %s", err)
}
+ testrig.StopWorkers(&suite.state)
}
func (suite *MediaUpdateTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
+
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
diff --git a/internal/api/client/reports/reports_test.go b/internal/api/client/reports/reports_test.go
index 1c5a532b9..cdab0b77b 100644
--- a/internal/api/client/reports/reports_test.go
+++ b/internal/api/client/reports/reports_test.go
@@ -21,14 +21,13 @@ package reports_test
import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/reports"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -42,6 +41,7 @@ type ReportsStandardTestSuite struct {
processor *processing.Processor
emailSender email.Sender
sentEmails map[string]string
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -67,19 +67,22 @@ func (suite *ReportsStandardTestSuite) SetupSuite() {
}
func (suite *ReportsStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.reportsModule = reports.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -90,4 +93,5 @@ func (suite *ReportsStandardTestSuite) SetupTest() {
func (suite *ReportsStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/api/client/search/search_test.go b/internal/api/client/search/search_test.go
index 4580f6f9d..153328cc3 100644
--- a/internal/api/client/search/search_test.go
+++ b/internal/api/client/search/search_test.go
@@ -26,16 +26,15 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/search"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -50,6 +49,7 @@ type SearchStandardTestSuite struct {
processor *processing.Processor
emailSender email.Sender
sentEmails map[string]string
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -71,19 +71,22 @@ func (suite *SearchStandardTestSuite) SetupSuite() {
}
func (suite *SearchStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.searchModule = search.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -94,6 +97,7 @@ func (suite *SearchStandardTestSuite) SetupTest() {
func (suite *SearchStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func (suite *SearchStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestPath string) *gin.Context {
diff --git a/internal/api/client/statuses/status_test.go b/internal/api/client/statuses/status_test.go
index a87fd36f7..93745ffd8 100644
--- a/internal/api/client/statuses/status_test.go
+++ b/internal/api/client/statuses/status_test.go
@@ -21,14 +21,13 @@ package statuses_test
import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/statuses"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -44,6 +43,7 @@ type StatusStandardTestSuite struct {
emailSender email.Sender
processor *processing.Processor
storage *storage.Driver
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -71,22 +71,26 @@ func (suite *StatusStandardTestSuite) SetupSuite() {
}
func (suite *StatusStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
- suite.tc = testrig.NewTestTypeConverter(suite.db)
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
+ suite.tc = testrig.NewTestTypeConverter(suite.db)
+
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.statusModule = statuses.New(suite.processor)
suite.NoError(suite.processor.Start())
@@ -95,4 +99,5 @@ func (suite *StatusStandardTestSuite) SetupTest() {
func (suite *StatusStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go
index 5fb470af8..ac27aad8a 100644
--- a/internal/api/client/streaming/streaming_test.go
+++ b/internal/api/client/streaming/streaming_test.go
@@ -32,15 +32,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/streaming"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -56,6 +55,7 @@ type StreamingTestSuite struct {
emailSender email.Sender
processor *processing.Processor
storage *storage.Driver
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -83,22 +83,25 @@ func (suite *StreamingTestSuite) SetupSuite() {
}
func (suite *StreamingTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
- suite.tc = testrig.NewTestTypeConverter(suite.db)
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
+ suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.streamingModule = streaming.New(suite.processor, 1, 4096)
suite.NoError(suite.processor.Start())
}
@@ -106,6 +109,7 @@ func (suite *StreamingTestSuite) SetupTest() {
func (suite *StreamingTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
// Addr is a fake network interface which implements the net.Addr interface
diff --git a/internal/api/client/user/user_test.go b/internal/api/client/user/user_test.go
index c990abb56..ce117059e 100644
--- a/internal/api/client/user/user_test.go
+++ b/internal/api/client/user/user_test.go
@@ -21,14 +21,13 @@ package user_test
import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/user"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -43,6 +42,7 @@ type UserStandardTestSuite struct {
emailSender email.Sender
processor *processing.Processor
storage *storage.Driver
+ state state.State
testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
@@ -56,23 +56,29 @@ type UserStandardTestSuite struct {
}
func (suite *UserStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
- suite.db = testrig.NewTestDB()
+
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.userModule = user.New(suite.processor)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@@ -83,4 +89,5 @@ func (suite *UserStandardTestSuite) SetupTest() {
func (suite *UserStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/api/fileserver/fileserver_test.go b/internal/api/fileserver/fileserver_test.go
index 0a6879e70..0e0dd9434 100644
--- a/internal/api/fileserver/fileserver_test.go
+++ b/internal/api/fileserver/fileserver_test.go
@@ -23,16 +23,15 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/fileserver"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -43,6 +42,7 @@ type FileserverTestSuite struct {
suite.Suite
db db.DB
storage *storage.Driver
+ state state.State
federator federation.Federator
tc typeutils.TypeConverter
processor *processing.Processor
@@ -67,26 +67,32 @@ type FileserverTestSuite struct {
*/
func (suite *FileserverTestSuite) SetupSuite() {
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
- suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
+ suite.state.Storage = suite.storage
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), suite.mediaManager)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker)
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
+ suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
suite.fileServer = fileserver.New(suite.processor)
}
func (suite *FileserverTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
suite.testTokens = testrig.NewTestTokens()
@@ -101,9 +107,11 @@ func (suite *FileserverTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
log.Panicf(nil, "error closing db connection: %s", err)
}
+ testrig.StopWorkers(&suite.state)
}
func (suite *FileserverTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/api/wellknown/webfinger/webfinger_test.go b/internal/api/wellknown/webfinger/webfinger_test.go
index 38228e928..3148279c5 100644
--- a/internal/api/wellknown/webfinger/webfinger_test.go
+++ b/internal/api/wellknown/webfinger/webfinger_test.go
@@ -26,15 +26,14 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -44,6 +43,7 @@ type WebfingerStandardTestSuite struct {
// standard suite interfaces
suite.Suite
db db.DB
+ state state.State
tc typeutils.TypeConverter
mediaManager media.Manager
federator federation.Federator
@@ -76,19 +76,21 @@ func (suite *WebfingerStandardTestSuite) SetupSuite() {
}
func (suite *WebfingerStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestLog()
testrig.InitTestConfig()
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Storage = suite.storage
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.webfingerModule = webfinger.New(suite.processor)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
@@ -100,6 +102,7 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
func (suite *WebfingerStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
func accountDomainAccount() *gtsmodel.Account {
diff --git a/internal/api/wellknown/webfinger/webfingerget_test.go b/internal/api/wellknown/webfinger/webfingerget_test.go
index 7587dfee1..a345d0602 100644
--- a/internal/api/wellknown/webfinger/webfingerget_test.go
+++ b/internal/api/wellknown/webfinger/webfingerget_test.go
@@ -30,9 +30,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -91,9 +89,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo
config.SetHost("gts.example.org")
config.SetAccountDomain("example.org")
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
+ suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender)
suite.webfingerModule = webfinger.New(suite.processor)
targetAccount := accountDomainAccount()
@@ -148,9 +144,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAc
config.SetHost("gts.example.org")
config.SetAccountDomain("example.org")
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
+ suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender)
suite.webfingerModule = webfinger.New(suite.processor)
targetAccount := accountDomainAccount()
diff --git a/internal/concurrency/workers.go b/internal/concurrency/workers.go
deleted file mode 100644
index ed99509cf..000000000
--- a/internal/concurrency/workers.go
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package concurrency
-
-import (
- "context"
- "errors"
- "fmt"
- "path"
- "reflect"
- "runtime"
-
- "codeberg.org/gruf/go-kv"
- "codeberg.org/gruf/go-runners"
- "github.com/superseriousbusiness/gotosocial/internal/log"
-)
-
-// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources.
-type WorkerPool[MsgType any] struct {
- workers runners.WorkerPool
- process func(context.Context, MsgType) error
- nw, nq int
- wtype string // contains worker type for logging
-}
-
-// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio,
-// where the queue ratio is multiplied by no. workers to get queue size. If args < 1
-// then suitable defaults are determined from the runtime's GOMAXPROCS variable.
-func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] {
- var zero MsgType
-
- if workers < 1 {
- // ensure sensible workers
- workers = runtime.GOMAXPROCS(0) * 4
- }
- if queueRatio < 1 {
- // ensure sensible ratio
- queueRatio = 100
- }
-
- // Calculate the short type string for the msg type
- msgType := reflect.TypeOf(zero).String()
- _, msgType = path.Split(msgType)
-
- w := &WorkerPool[MsgType]{
- process: nil,
- nw: workers,
- nq: workers * queueRatio,
- wtype: fmt.Sprintf("worker.Worker[%s]", msgType),
- }
-
- // Log new worker creation with worker type prefix
- log.Infof(nil, "%s created with workers=%d queue=%d",
- w.wtype,
- workers,
- workers*queueRatio,
- )
-
- return w
-}
-
-// Start will attempt to start the underlying worker pool, or return error.
-func (w *WorkerPool[MsgType]) Start() error {
- log.Infof(nil, "%s starting", w.wtype)
-
- // Check processor was set
- if w.process == nil {
- return errors.New("nil Worker.process function")
- }
-
- // Attempt to start pool
- if !w.workers.Start(w.nw, w.nq) {
- return errors.New("failed to start Worker pool")
- }
-
- return nil
-}
-
-// Stop will attempt to stop the underlying worker pool, or return error.
-func (w *WorkerPool[MsgType]) Stop() error {
- log.Infof(nil, "%s stopping", w.wtype)
-
- // Attempt to stop pool
- if !w.workers.Stop() {
- return errors.New("failed to stop Worker pool")
- }
-
- return nil
-}
-
-// SetProcessor will set the Worker's processor function, which is called for each queued message.
-func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
- if w.process != nil {
- log.Panicf(nil, "%s Worker.process is already set", w.wtype)
- }
- w.process = fn
-}
-
-// Queue will queue provided message to be processed with there's a free worker.
-func (w *WorkerPool[MsgType]) Queue(msg MsgType) {
- log.Tracef(nil, "%s queueing message: %+v", w.wtype, msg)
-
- // Create new process function for msg
- process := func(ctx context.Context) {
- if err := w.process(ctx, msg); err != nil {
- log.WithContext(ctx).
- WithFields(kv.Fields{
- kv.Field{K: "type", V: w.wtype},
- kv.Field{K: "error", V: err},
- }...).Error("message processing error")
- }
- }
-
- // Attempt a fast-enqueue of process
- if !w.workers.EnqueueNow(process) {
- // No spot acquired, log warning
- log.WithFields(kv.Fields{
- kv.Field{K: "type", V: w.wtype},
- kv.Field{K: "queue", V: w.workers.Queue()},
- }...).Warn("full worker queue")
-
- // Block on enqueuing process func
- w.workers.Enqueue(process)
- }
-}
diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go
index b0da97ef1..ce255d036 100644
--- a/internal/db/bundb/admin_test.go
+++ b/internal/db/bundb/admin_test.go
@@ -70,8 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
}
func (suite *AdminTestSuite) TestCreateInstanceAccount() {
- // reinitialize test DB to clear caches
- suite.db = testrig.NewTestDB()
+ // reinitialize db caches to clear
+ suite.state.Caches.Init()
// we need to take an empty db for this...
testrig.StandardDBTeardown(suite.db)
// ...with tables created but no data
diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go
index e050c2b5d..bad8bfc72 100644
--- a/internal/db/bundb/bundb_test.go
+++ b/internal/db/bundb/bundb_test.go
@@ -22,13 +22,15 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type BunDBStandardTestSuite struct {
// standard suite interfaces
suite.Suite
- db db.DB
+ db db.DB
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -61,9 +63,10 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
}
func (suite *BunDBStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
diff --git a/internal/federation/dereferencing/dereferencer_test.go b/internal/federation/dereferencing/dereferencer_test.go
index daca8b7de..f5b59b0ed 100644
--- a/internal/federation/dereferencing/dereferencer_test.go
+++ b/internal/federation/dereferencing/dereferencer_test.go
@@ -21,11 +21,10 @@ package dereferencing_test
import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams/vocab"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -34,6 +33,7 @@ type DereferencerStandardTestSuite struct {
suite.Suite
db db.DB
storage *storage.Driver
+ state state.State
testRemoteStatuses map[string]vocab.ActivityStreamsNote
testRemotePeople map[string]vocab.ActivityStreamsPerson
@@ -58,12 +58,19 @@ func (suite *DereferencerStandardTestSuite) SetupTest() {
suite.testRemoteAttachments = testrig.NewTestFediAttachments("../../../testrig/media")
suite.testEmojis = testrig.NewTestEmojis()
- suite.db = testrig.NewTestDB()
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
+ suite.db = testrig.NewTestDB(&suite.state)
suite.storage = testrig.NewInMemoryStorage()
- suite.dereferencer = dereferencing.NewDereferencer(suite.db, testrig.NewTestTypeConverter(suite.db), testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1)), testrig.NewTestMediaManager(suite.db, suite.storage))
+ suite.state.DB = suite.db
+ suite.state.Storage = suite.storage
+ media := testrig.NewTestMediaManager(&suite.state)
+ suite.dereferencer = dereferencing.NewDereferencer(suite.db, testrig.NewTestTypeConverter(suite.db), testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), media)
testrig.StandardDBSetup(suite.db, nil)
}
func (suite *DereferencerStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/federation/federatingactor_test.go b/internal/federation/federatingactor_test.go
index 0d1d8e37f..f63ecd827 100644
--- a/internal/federation/federatingactor_test.go
+++ b/internal/federation/federatingactor_test.go
@@ -27,10 +27,8 @@ import (
"time"
"github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -56,14 +54,12 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
// setup transport controller with a no-op client so we don't make external calls
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity)
suite.NoError(err)
@@ -105,12 +101,10 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), testrig.TimeMustParse("2022-06-02T12:22:21+02:00"), testNote)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity)
suite.NoError(err)
diff --git a/internal/federation/federatingdb/accept.go b/internal/federation/federatingdb/accept.go
index d3e227a10..184d2b09d 100644
--- a/internal/federation/federatingdb/accept.go
+++ b/internal/federation/federatingdb/accept.go
@@ -65,7 +65,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if uris.IsFollowPath(acceptedObjectIRI) {
// ACCEPT FOLLOW
gtsFollowRequest := >smodel.FollowRequest{}
- if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil {
+ if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil {
return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err)
}
@@ -73,12 +73,12 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if gtsFollowRequest.AccountID != receivingAccount.ID {
return errors.New("ACCEPT: follow object account and inbox account were not the same")
}
- follow, err := f.db.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID)
+ follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID)
if err != nil {
return err
}
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept,
GTSModel: follow,
@@ -108,12 +108,12 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if gtsFollow.AccountID != receivingAccount.ID {
return errors.New("ACCEPT: follow object account and inbox account were not the same")
}
- follow, err := f.db.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID)
+ follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID)
if err != nil {
return err
}
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept,
GTSModel: follow,
diff --git a/internal/federation/federatingdb/announce.go b/internal/federation/federatingdb/announce.go
index f4d145148..552a95ba9 100644
--- a/internal/federation/federatingdb/announce.go
+++ b/internal/federation/federatingdb/announce.go
@@ -59,7 +59,7 @@ func (f *federatingDB) Announce(ctx context.Context, announce vocab.ActivityStre
}
// it's a new announce so pass it back to the processor async for dereferencing etc
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityCreate,
GTSModel: boost,
diff --git a/internal/federation/federatingdb/announce_test.go b/internal/federation/federatingdb/announce_test.go
index 6c0d969f4..d9158f383 100644
--- a/internal/federation/federatingdb/announce_test.go
+++ b/internal/federation/federatingdb/announce_test.go
@@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
)
type AnnounceTestSuite struct {
@@ -74,6 +75,13 @@ func (suite *AnnounceTestSuite) TestAnnounceTwice() {
suite.True(ok)
suite.Equal(announcingAccount.ID, boost.AccountID)
+ // Insert the boost-of status into the
+ // DB cache to emulate processor handling
+ boost.ID, _ = id.NewULIDFromTime(boost.CreatedAt)
+ suite.state.Caches.GTS.Status().Store(boost, func() error {
+ return nil
+ })
+
// only the URI will be set on the boosted status because it still needs to be dereferenced
suite.NotEmpty(boost.BoostOf.URI)
diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go
index bf3e7f75d..ca87131fe 100644
--- a/internal/federation/federatingdb/create.go
+++ b/internal/federation/federatingdb/create.go
@@ -103,11 +103,11 @@ func (f *federatingDB) activityBlock(ctx context.Context, asType vocab.Type, rec
block.ID = id.NewULID()
- if err := f.db.PutBlock(ctx, block); err != nil {
+ if err := f.state.DB.PutBlock(ctx, block); err != nil {
return fmt.Errorf("activityBlock: database error inserting block: %s", err)
}
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityBlock,
APActivityType: ap.ActivityCreate,
GTSModel: block,
@@ -202,7 +202,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream
return nil
}
// pass the note iri into the processor and have it do the dereferencing instead of doing it here
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
APIri: id.GetIRI(),
@@ -226,7 +226,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream
}
status.ID = statusID
- if err := f.db.PutStatus(ctx, status); err != nil {
+ if err := f.state.DB.PutStatus(ctx, status); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
// the status already exists in the database, which means we've already handled everything else,
// so we can just return nil here and be done with it.
@@ -236,7 +236,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream
return fmt.Errorf("createNote: database error inserting status: %s", err)
}
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
GTSModel: status,
@@ -263,11 +263,11 @@ func (f *federatingDB) activityFollow(ctx context.Context, asType vocab.Type, re
followRequest.ID = id.NewULID()
- if err := f.db.Put(ctx, followRequest); err != nil {
+ if err := f.state.DB.Put(ctx, followRequest); err != nil {
return fmt.Errorf("activityFollow: database error inserting follow request: %s", err)
}
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityCreate,
GTSModel: followRequest,
@@ -294,11 +294,11 @@ func (f *federatingDB) activityLike(ctx context.Context, asType vocab.Type, rece
fave.ID = id.NewULID()
- if err := f.db.Put(ctx, fave); err != nil {
+ if err := f.state.DB.Put(ctx, fave); err != nil {
return fmt.Errorf("activityLike: database error inserting fave: %s", err)
}
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityLike,
APActivityType: ap.ActivityCreate,
GTSModel: fave,
@@ -325,11 +325,11 @@ func (f *federatingDB) activityFlag(ctx context.Context, asType vocab.Type, rece
report.ID = id.NewULID()
- if err := f.db.PutReport(ctx, report); err != nil {
+ if err := f.state.DB.PutReport(ctx, report); err != nil {
return fmt.Errorf("activityFlag: database error inserting report: %w", err)
}
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityFlag,
APActivityType: ap.ActivityCreate,
GTSModel: report,
diff --git a/internal/federation/federatingdb/db.go b/internal/federation/federatingdb/db.go
index 24455a553..af4aceeeb 100644
--- a/internal/federation/federatingdb/db.go
+++ b/internal/federation/federatingdb/db.go
@@ -24,9 +24,7 @@ import (
"codeberg.org/gruf/go-mutexes"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams/vocab"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
)
@@ -43,17 +41,15 @@ type DB interface {
// It doesn't care what the underlying implementation of the DB interface is, as long as it works.
type federatingDB struct {
locks mutexes.MutexMap
- db db.DB
- fedWorker *concurrency.WorkerPool[messages.FromFederator]
+ state *state.State
typeConverter typeutils.TypeConverter
}
// New returns a DB interface using the given database and config
-func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator], tc typeutils.TypeConverter) DB {
+func New(state *state.State, tc typeutils.TypeConverter) DB {
fdb := federatingDB{
locks: mutexes.NewMap(-1, -1), // use defaults
- db: db,
- fedWorker: fedWorker,
+ state: state,
typeConverter: tc,
}
return &fdb
diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go
index a1890b56b..695f199b4 100644
--- a/internal/federation/federatingdb/delete.go
+++ b/internal/federation/federatingdb/delete.go
@@ -51,9 +51,9 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
// in a delete we only get the URI, we can't know if we have a status or a profile or something else,
// so we have to try a few different things...
- if s, err := f.db.GetStatusByURI(ctx, id.String()); err == nil && requestingAccount.ID == s.AccountID {
+ if s, err := f.state.DB.GetStatusByURI(ctx, id.String()); err == nil && requestingAccount.ID == s.AccountID {
l.Debugf("uri is for STATUS with id: %s", s.ID)
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityDelete,
GTSModel: s,
@@ -61,9 +61,9 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
})
}
- if a, err := f.db.GetAccountByURI(ctx, id.String()); err == nil && requestingAccount.ID == a.ID {
+ if a, err := f.state.DB.GetAccountByURI(ctx, id.String()); err == nil && requestingAccount.ID == a.ID {
l.Debugf("uri is for ACCOUNT with id %s", a.ID)
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityDelete,
GTSModel: a,
diff --git a/internal/federation/federatingdb/federatingdb_test.go b/internal/federation/federatingdb/federatingdb_test.go
index dd5a5f5f9..b0893f246 100644
--- a/internal/federation/federatingdb/federatingdb_test.go
+++ b/internal/federation/federatingdb/federatingdb_test.go
@@ -23,11 +23,11 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -36,9 +36,9 @@ type FederatingDBTestSuite struct {
suite.Suite
db db.DB
tc typeutils.TypeConverter
- fedWorker *concurrency.WorkerPool[messages.FromFederator]
fromFederator chan messages.FromFederator
federatingDB federatingdb.DB
+ state state.State
testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client
@@ -66,22 +66,33 @@ func (suite *FederatingDBTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
suite.fromFederator = make(chan messages.FromFederator, 10)
- suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error {
+ suite.state.Workers.EnqueueFederator = func(ctx context.Context, msg messages.FromFederator) {
suite.fromFederator <- msg
- return nil
- })
- _ = suite.fedWorker.Start()
- suite.db = testrig.NewTestDB()
+ }
+
+ suite.db = testrig.NewTestDB(&suite.state)
suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.federatingDB = testrig.NewTestFederatingDB(suite.db, suite.fedWorker)
+ suite.federatingDB = testrig.NewTestFederatingDB(&suite.state)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
+
+ suite.state.DB = suite.db
}
func (suite *FederatingDBTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
+ testrig.StopWorkers(&suite.state)
+ for suite.fromFederator != nil {
+ select {
+ case <-suite.fromFederator:
+ default:
+ return
+ }
+ }
}
func createTestContext(receivingAccount *gtsmodel.Account, requestingAccount *gtsmodel.Account) context.Context {
diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go
index c47a2b625..69746c99b 100644
--- a/internal/federation/federatingdb/followers.go
+++ b/internal/federation/federatingdb/followers.go
@@ -29,7 +29,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
return nil, err
}
- acctFollowers, err := f.db.GetAccountFollowedBy(ctx, acct.ID, false)
+ acctFollowers, err := f.state.DB.GetAccountFollowedBy(ctx, acct.ID, false)
if err != nil {
return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err)
}
@@ -37,7 +37,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
iris := []*url.URL{}
for _, follow := range acctFollowers {
if follow.Account == nil {
- a, err := f.db.GetAccountByID(ctx, follow.AccountID)
+ a, err := f.state.DB.GetAccountByID(ctx, follow.AccountID)
if err != nil {
errWrapped := fmt.Errorf("Followers: db error getting account id %s: %s", follow.AccountID, err)
if err == db.ErrNoEntries {
diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go
index f4f07bb25..9c22c0574 100644
--- a/internal/federation/federatingdb/following.go
+++ b/internal/federation/federatingdb/following.go
@@ -47,7 +47,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
return nil, err
}
- acctFollowing, err := f.db.GetAccountFollows(ctx, acct.ID)
+ acctFollowing, err := f.state.DB.GetAccountFollows(ctx, acct.ID)
if err != nil {
return nil, fmt.Errorf("Following: db error getting following for account id %s: %s", acct.ID, err)
}
@@ -55,7 +55,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
iris := []*url.URL{}
for _, follow := range acctFollowing {
if follow.TargetAccount == nil {
- a, err := f.db.GetAccountByID(ctx, follow.TargetAccountID)
+ a, err := f.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil {
errWrapped := fmt.Errorf("Following: db error getting account id %s: %s", follow.TargetAccountID, err)
if err == db.ErrNoEntries {
diff --git a/internal/federation/federatingdb/get.go b/internal/federation/federatingdb/get.go
index 92a79d70f..1d687f110 100644
--- a/internal/federation/federatingdb/get.go
+++ b/internal/federation/federatingdb/get.go
@@ -39,13 +39,13 @@ func (f *federatingDB) Get(ctx context.Context, id *url.URL) (value vocab.Type,
switch {
case uris.IsUserPath(id):
- acct, err := f.db.GetAccountByURI(ctx, id.String())
+ acct, err := f.state.DB.GetAccountByURI(ctx, id.String())
if err != nil {
return nil, err
}
return f.typeConverter.AccountToAS(ctx, acct)
case uris.IsStatusesPath(id):
- status, err := f.db.GetStatusByURI(ctx, id.String())
+ status, err := f.state.DB.GetStatusByURI(ctx, id.String())
if err != nil {
return nil, err
}
diff --git a/internal/federation/federatingdb/inbox.go b/internal/federation/federatingdb/inbox.go
index 5ec735bd4..1a6da4ef0 100644
--- a/internal/federation/federatingdb/inbox.go
+++ b/internal/federation/federatingdb/inbox.go
@@ -85,12 +85,12 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
return nil, fmt.Errorf("couldn't extract local account username from uri %s: %s", iri, err)
}
- account, err := f.db.GetAccountByUsernameDomain(c, localAccountUsername, "")
+ account, err := f.state.DB.GetAccountByUsernameDomain(c, localAccountUsername, "")
if err != nil {
return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err)
}
- follows, err := f.db.GetAccountFollowedBy(c, account.ID, false)
+ follows, err := f.state.DB.GetAccountFollowedBy(c, account.ID, false)
if err != nil {
return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err)
}
@@ -98,7 +98,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
for _, follow := range follows {
// make sure we retrieved the following account from the db
if follow.Account == nil {
- followingAccount, err := f.db.GetAccountByID(c, follow.AccountID)
+ followingAccount, err := f.state.DB.GetAccountByID(c, follow.AccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
@@ -126,7 +126,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
}
// check if this is just an account IRI...
- if account, err := f.db.GetAccountByURI(c, iri.String()); err == nil {
+ if account, err := f.state.DB.GetAccountByURI(c, iri.String()); err == nil {
// deliver to a shared inbox if we have that option
var inbox string
if config.GetInstanceDeliverToSharedInboxes() && account.SharedInboxURI != nil && *account.SharedInboxURI != "" {
diff --git a/internal/federation/federatingdb/owns.go b/internal/federation/federatingdb/owns.go
index def0fa518..2c11e8148 100644
--- a/internal/federation/federatingdb/owns.go
+++ b/internal/federation/federatingdb/owns.go
@@ -54,7 +54,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
- status, err := f.db.GetStatusByURI(ctx, uid)
+ status, err := f.state.DB.GetStatusByURI(ctx, uid)
if err != nil {
if err == db.ErrNoEntries {
// there are no entries for this status
@@ -71,7 +71,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
+ if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -88,7 +88,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
+ if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -105,7 +105,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
+ if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -122,7 +122,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
+ if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -130,7 +130,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
// an actual error happened
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
- if err := f.db.GetByID(ctx, likeID, >smodel.StatusFave{}); err != nil {
+ if err := f.state.DB.GetByID(ctx, likeID, >smodel.StatusFave{}); err != nil {
if err == db.ErrNoEntries {
// there are no entries
return false, nil
@@ -147,7 +147,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
+ if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -155,7 +155,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
// an actual error happened
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
- if err := f.db.GetByID(ctx, blockID, >smodel.Block{}); err != nil {
+ if err := f.state.DB.GetByID(ctx, blockID, >smodel.Block{}); err != nil {
if err == db.ErrNoEntries {
// there are no entries
return false, nil
diff --git a/internal/federation/federatingdb/reject.go b/internal/federation/federatingdb/reject.go
index 3c3cd7c75..d443cd6cb 100644
--- a/internal/federation/federatingdb/reject.go
+++ b/internal/federation/federatingdb/reject.go
@@ -64,7 +64,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
if uris.IsFollowPath(rejectedObjectIRI) {
// REJECT FOLLOW
gtsFollowRequest := >smodel.FollowRequest{}
- if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil {
+ if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil {
return fmt.Errorf("Reject: couldn't get follow request with id %s from the database: %s", rejectedObjectIRI.String(), err)
}
@@ -73,7 +73,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
return errors.New("Reject: follow object account and inbox account were not the same")
}
- if _, err := f.db.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil {
+ if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil {
return err
}
@@ -102,7 +102,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
if gtsFollow.AccountID != receivingAccount.ID {
return errors.New("Reject: follow object account and inbox account were not the same")
}
- if _, err := f.db.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil {
+ if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil {
return err
}
diff --git a/internal/federation/federatingdb/undo.go b/internal/federation/federatingdb/undo.go
index b239aabb4..e33b365fa 100644
--- a/internal/federation/federatingdb/undo.go
+++ b/internal/federation/federatingdb/undo.go
@@ -81,11 +81,11 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: follow object account and inbox account were not the same")
}
// delete any existing FOLLOW
- if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil {
+ if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil {
return fmt.Errorf("UNDO: db error removing follow: %s", err)
}
// delete any existing FOLLOW REQUEST
- if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil {
+ if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil {
return fmt.Errorf("UNDO: db error removing follow request: %s", err)
}
l.Debug("follow undone")
@@ -114,7 +114,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: block object account and inbox account were not the same")
}
// delete any existing BLOCK
- if err := f.db.DeleteBlockByURI(ctx, gtsBlock.URI); err != nil {
+ if err := f.state.DB.DeleteBlockByURI(ctx, gtsBlock.URI); err != nil {
return fmt.Errorf("UNDO: db error removing block: %s", err)
}
l.Debug("block undone")
diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go
index 570729a31..bed5de4db 100644
--- a/internal/federation/federatingdb/update.go
+++ b/internal/federation/federatingdb/update.go
@@ -138,7 +138,7 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {
// pass to the processor for further updating of eg., avatar/header, emojis
// the actual db insert/update will take place a bit later
- f.fedWorker.Queue(messages.FromFederator{
+ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate,
GTSModel: updatedAcct,
diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go
index 64f32d39c..f63eb6dc9 100644
--- a/internal/federation/federatingdb/util.go
+++ b/internal/federation/federatingdb/util.go
@@ -95,7 +95,7 @@ func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (idURL *url.URL,
// take the IRI of the first actor we can find (there should only be one)
if iter.IsIRI() {
// if there's an error here, just use the fallback behavior -- we don't need to return an error here
- if actorAccount, err := f.db.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil {
+ if actorAccount, err := f.state.DB.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil {
newID, err := id.NewRandomULID()
if err != nil {
return nil, err
@@ -238,7 +238,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
switch {
case uris.IsUserPath(iri):
- if acct, err = f.db.GetAccountByURI(ctx, iri.String()); err != nil {
+ if acct, err = f.state.DB.GetAccountByURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to uri %s", iri.String())
}
@@ -246,7 +246,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsInboxPath(iri):
- if err = f.db.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil {
+ if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", iri.String())
}
@@ -254,7 +254,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsOutboxPath(iri):
- if err = f.db.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil {
+ if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", iri.String())
}
@@ -262,7 +262,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsFollowersPath(iri):
- if err = f.db.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil {
+ if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to followers_uri %s", iri.String())
}
@@ -270,7 +270,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsFollowingPath(iri):
- if err = f.db.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil {
+ if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to following_uri %s", iri.String())
}
diff --git a/internal/federation/federatingprotocol_test.go b/internal/federation/federatingprotocol_test.go
index faa168a71..e66cd78cb 100644
--- a/internal/federation/federatingprotocol_test.go
+++ b/internal/federation/federatingprotocol_test.go
@@ -28,10 +28,8 @@ import (
"github.com/go-fed/httpsig"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -43,12 +41,10 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook1() {
// the activity we're gonna use
activity := suite.testActivities["dm_for_zork"]
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
// setup request
ctx := context.Background()
@@ -74,13 +70,11 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook2() {
// the activity we're gonna use
activity := suite.testActivities["reply_to_turtle_for_zork"]
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
// setup request
ctx := context.Background()
@@ -107,13 +101,11 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook3() {
// the activity we're gonna use
activity := suite.testActivities["reply_to_turtle_for_turtle"]
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
// setup request
ctx := context.Background()
@@ -142,13 +134,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
// now setup module being tested, with the mock transport controller
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)
// we need these headers for the request to be validated
@@ -187,13 +177,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGone() {
activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"]
inboxAccount := suite.testAccounts["local_account_1"]
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
// now setup module being tested, with the mock transport controller
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)
// we need these headers for the request to be validated
@@ -231,13 +219,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGoneNoTombstoneYet
activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"]
inboxAccount := suite.testAccounts["local_account_1"]
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
// now setup module being tested, with the mock transport controller
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)
// we need these headers for the request to be validated
@@ -271,10 +257,9 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGoneNoTombstoneYet
}
func (suite *FederatingProtocolTestSuite) TestBlocked1() {
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
@@ -294,10 +279,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked1() {
}
func (suite *FederatingProtocolTestSuite) TestBlocked2() {
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
@@ -328,10 +312,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked2() {
}
func (suite *FederatingProtocolTestSuite) TestBlocked3() {
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
@@ -365,10 +348,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked3() {
}
func (suite *FederatingProtocolTestSuite) TestBlocked4() {
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
- tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
- federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage))
+ tc := testrig.NewTestTransportController(&suite.state, httpClient)
+ federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
diff --git a/internal/federation/federator_test.go b/internal/federation/federator_test.go
index da6038ace..8a045aa1f 100644
--- a/internal/federation/federator_test.go
+++ b/internal/federation/federator_test.go
@@ -23,6 +23,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -32,6 +33,7 @@ type FederatorStandardTestSuite struct {
suite.Suite
db db.DB
storage *storage.Driver
+ state state.State
tc typeutils.TypeConverter
testAccounts map[string]*gtsmodel.Account
testStatuses map[string]*gtsmodel.Status
@@ -42,8 +44,9 @@ type FederatorStandardTestSuite struct {
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *FederatorStandardTestSuite) SetupSuite() {
// setup standard items
+ testrig.StartWorkers(&suite.state)
suite.storage = testrig.NewInMemoryStorage()
- suite.tc = testrig.NewTestTypeConverter(suite.db)
+ suite.state.Storage = suite.storage
suite.testAccounts = testrig.NewTestAccounts()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTombstones = testrig.NewTestTombstones()
@@ -52,7 +55,10 @@ func (suite *FederatorStandardTestSuite) SetupSuite() {
func (suite *FederatorStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
+ suite.state.Caches.Init()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.tc = testrig.NewTestTypeConverter(suite.db)
+ suite.state.DB = suite.db
suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
diff --git a/internal/media/media_test.go b/internal/media/media_test.go
index d9f01c1ff..393126ac7 100644
--- a/internal/media/media_test.go
+++ b/internal/media/media_test.go
@@ -20,11 +20,10 @@ package media_test
import (
"github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -35,6 +34,7 @@ type MediaStandardTestSuite struct {
db db.DB
storage *storage.Driver
+ state state.State
manager media.Manager
transportController transport.Controller
testAttachments map[string]*gtsmodel.MediaAttachment
@@ -46,21 +46,27 @@ func (suite *MediaStandardTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.DB = suite.db
+ suite.state.Storage = suite.storage
}
func (suite *MediaStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
testrig.StandardDBSetup(suite.db, nil)
suite.testAttachments = testrig.NewTestAttachments()
suite.testAccounts = testrig.NewTestAccounts()
suite.testEmojis = testrig.NewTestEmojis()
- suite.manager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](0, 0))
+ suite.manager = testrig.NewTestMediaManager(&suite.state)
+ suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../testrig/media"))
}
func (suite *MediaStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go
index 92c117bb3..a243383da 100644
--- a/internal/oauth/clientstore_test.go
+++ b/internal/oauth/clientstore_test.go
@@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig"
"github.com/superseriousbusiness/oauth2/v4/models"
)
@@ -32,6 +33,7 @@ import (
type PgClientStoreTestSuite struct {
suite.Suite
db db.DB
+ state state.State
testClientID string
testClientSecret string
testClientDomain string
@@ -48,9 +50,11 @@ func (suite *PgClientStoreTestSuite) SetupSuite() {
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
func (suite *PgClientStoreTestSuite) SetupTest() {
+ suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
testrig.StandardDBSetup(suite.db, nil)
}
diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go
index 41315d483..62330c0dc 100644
--- a/internal/processing/account/account.go
+++ b/internal/processing/account/account.go
@@ -19,13 +19,11 @@
package account
import (
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
@@ -35,35 +33,32 @@ import (
//
// It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc.
type Processor struct {
+ state *state.State
tc typeutils.TypeConverter
mediaManager media.Manager
- clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
oauthServer oauth.Server
filter visibility.Filter
formatter text.Formatter
- db db.DB
federator federation.Federator
parseMention gtsmodel.ParseMentionFunc
}
// New returns a new account processor.
func New(
- db db.DB,
+ state *state.State,
tc typeutils.TypeConverter,
mediaManager media.Manager,
oauthServer oauth.Server,
- clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
federator federation.Federator,
parseMention gtsmodel.ParseMentionFunc,
) Processor {
return Processor{
+ state: state,
tc: tc,
mediaManager: mediaManager,
- clientWorker: clientWorker,
oauthServer: oauthServer,
- filter: visibility.NewFilter(db),
- formatter: text.NewFormatter(db),
- db: db,
+ filter: visibility.NewFilter(state.DB),
+ formatter: text.NewFormatter(state.DB),
federator: federator,
parseMention: parseMention,
}
diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go
index 2e7cdb994..7a2e5aa8d 100644
--- a/internal/processing/account/account_test.go
+++ b/internal/processing/account/account_test.go
@@ -22,7 +22,6 @@ import (
"context"
"github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -32,6 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
@@ -44,6 +44,7 @@ type AccountStandardTestSuite struct {
db db.DB
tc typeutils.TypeConverter
storage *storage.Driver
+ state state.State
mediaManager media.Manager
oauthServer oauth.Server
fromClientAPIChan chan messages.FromClientAPI
@@ -76,30 +77,30 @@ func (suite *AccountStandardTestSuite) SetupSuite() {
}
func (suite *AccountStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestLog()
testrig.InitTestConfig()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
- suite.fromClientAPIChan <- msg
- return nil
- })
-
- _ = fedWorker.Start()
- _ = clientWorker.Start()
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
+ suite.state.Storage = suite.storage
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
+
suite.fromClientAPIChan = make(chan messages.FromClientAPI, 100)
- suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker)
- suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker)
+ suite.state.Workers.EnqueueClientAPI = func(ctx context.Context, msg messages.FromClientAPI) {
+ suite.fromClientAPIChan <- msg
+ }
+
+ suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
+ suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
- suite.accountProcessor = account.New(suite.db, suite.tc, suite.mediaManager, suite.oauthServer, clientWorker, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator))
+ suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
}
@@ -107,4 +108,5 @@ func (suite *AccountStandardTestSuite) SetupTest() {
func (suite *AccountStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/processing/account/block.go b/internal/processing/account/block.go
index 99effd3a3..edec106b1 100644
--- a/internal/processing/account/block.go
+++ b/internal/processing/account/block.go
@@ -36,13 +36,13 @@ import (
// BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local.
func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db
- targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
+ targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
}
// if requestingAccount already blocks target account, we don't need to do anything
- if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil {
+ if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err))
} else if blocked {
return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
@@ -64,18 +64,18 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
block.URI = uris.GenerateURIForBlock(requestingAccount.Username, newBlockID)
// whack it in the database
- if err := p.db.PutBlock(ctx, block); err != nil {
+ if err := p.state.DB.PutBlock(ctx, block); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err))
}
// clear any follows or follow requests from the blocked account to the target account -- this is a simple delete
- if err := p.db.DeleteWhere(ctx, []db.Where{
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{
{Key: "account_id", Value: targetAccountID},
{Key: "target_account_id", Value: requestingAccount.ID},
}, >smodel.Follow{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err))
}
- if err := p.db.DeleteWhere(ctx, []db.Where{
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{
{Key: "account_id", Value: targetAccountID},
{Key: "target_account_id", Value: requestingAccount.ID},
}, >smodel.FollowRequest{}); err != nil {
@@ -89,12 +89,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
var frChanged bool
var frURI string
fr := >smodel.FollowRequest{}
- if err := p.db.GetWhere(ctx, []db.Where{
+ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, fr); err == nil {
frURI = fr.URI
- if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil {
+ if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err))
}
frChanged = true
@@ -104,12 +104,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
var fChanged bool
var fURI string
f := >smodel.Follow{}
- if err := p.db.GetWhere(ctx, []db.Where{
+ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, f); err == nil {
fURI = f.URI
- if err := p.db.DeleteByID(ctx, f.ID, f); err != nil {
+ if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err))
}
fChanged = true
@@ -117,7 +117,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
// follow request status changed so send the UNDO activity to the channel for async processing
if frChanged {
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo,
GTSModel: >smodel.Follow{
@@ -132,7 +132,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
// follow status changed so send the UNDO activity to the channel for async processing
if fChanged {
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo,
GTSModel: >smodel.Follow{
@@ -146,7 +146,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
}
// handle the rest of the block process asynchronously
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityBlock,
APActivityType: ap.ActivityCreate,
GTSModel: block,
@@ -160,23 +160,23 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
// BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local.
func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db
- targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
+ targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
}
// check if a block exists, and remove it if it does
- block, err := p.db.GetBlock(ctx, requestingAccount.ID, targetAccountID)
+ block, err := p.state.DB.GetBlock(ctx, requestingAccount.ID, targetAccountID)
if err == nil {
// we got a block, remove it
block.Account = requestingAccount
block.TargetAccount = targetAccount
- if err := p.db.DeleteBlockByID(ctx, block.ID); err != nil {
+ if err := p.state.DB.DeleteBlockByID(ctx, block.ID); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err))
}
// send the UNDO activity to the client worker for async processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityBlock,
APActivityType: ap.ActivityUndo,
GTSModel: block,
diff --git a/internal/processing/account/bookmarks.go b/internal/processing/account/bookmarks.go
index 28688c20d..cf53e63bb 100644
--- a/internal/processing/account/bookmarks.go
+++ b/internal/processing/account/bookmarks.go
@@ -34,7 +34,7 @@ import (
// BookmarksGet returns a pageable response of statuses that are bookmarked by requestingAccount.
// Paging for this response is done based on bookmark ID rather than status ID.
func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmodel.Account, limit int, maxID string, minID string) (*apimodel.PageableResponse, gtserror.WithCode) {
- bookmarks, err := p.db.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID)
+ bookmarks, err := p.state.DB.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -47,7 +47,7 @@ func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmode
)
for _, bookmark := range bookmarks {
- status, err := p.db.GetStatusByID(ctx, bookmark.StatusID)
+ status, err := p.state.DB.GetStatusByID(ctx, bookmark.StatusID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// We just don't have the status for some reason.
diff --git a/internal/processing/account/create.go b/internal/processing/account/create.go
index 8b82bc681..9c9cfb57f 100644
--- a/internal/processing/account/create.go
+++ b/internal/processing/account/create.go
@@ -35,7 +35,7 @@ import (
// Create processes the given form for creating a new account, returning an oauth token for that account if successful.
func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, gtserror.WithCode) {
- emailAvailable, err := p.db.IsEmailAvailable(ctx, form.Email)
+ emailAvailable, err := p.state.DB.IsEmailAvailable(ctx, form.Email)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err)
}
@@ -43,7 +43,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", form.Email))
}
- usernameAvailable, err := p.db.IsUsernameAvailable(ctx, form.Username)
+ usernameAvailable, err := p.state.DB.IsUsernameAvailable(ctx, form.Username)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err)
}
@@ -61,7 +61,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
}
log.Trace(ctx, "creating new username and account")
- user, err := p.db.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false)
+ user, err := p.state.DB.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error creating new signup in the database: %s", err))
}
@@ -73,7 +73,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
}
if user.Account == nil {
- a, err := p.db.GetAccountByID(ctx, user.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, user.AccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting new account from the database: %s", err))
}
@@ -82,7 +82,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
// there are side effects for creating a new account (sending confirmation emails etc)
// so pass a message to the processor so that it can do it asynchronously
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityCreate,
GTSModel: user.Account,
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index 58a967337..eea4a621e 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -54,22 +54,22 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
if account.Domain == "" {
// see if we can get a user for this account
var err error
- if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil {
+ if user, err = p.state.DB.GetUserByAccountID(ctx, account.ID); err == nil {
// we got one! select all tokens with the user's ID
tokens := []*gtsmodel.Token{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {
+ if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {
// we have some tokens to delete
for _, t := range tokens {
// delete client(s) associated with this token
- if err := p.db.DeleteByID(ctx, t.ClientID, >smodel.Client{}); err != nil {
+ if err := p.state.DB.DeleteByID(ctx, t.ClientID, >smodel.Client{}); err != nil {
l.Errorf("error deleting oauth client: %s", err)
}
// delete application(s) associated with this token
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil {
l.Errorf("error deleting application: %s", err)
}
// delete the token itself
- if err := p.db.DeleteByID(ctx, t.ID, t); err != nil {
+ if err := p.state.DB.DeleteByID(ctx, t.ID, t); err != nil {
l.Errorf("error deleting oauth token: %s", err)
}
}
@@ -80,12 +80,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// 2. Delete account's blocks
l.Trace("deleting account blocks")
// first delete any blocks that this account created
- if err := p.db.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil {
+ if err := p.state.DB.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil {
l.Errorf("error deleting blocks created by account: %s", err)
}
// now delete any blocks that target this account
- if err := p.db.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil {
+ if err := p.state.DB.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil {
l.Errorf("error deleting blocks targeting account: %s", err)
}
@@ -96,12 +96,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// TODO: federate these if necessary
l.Trace("deleting account follow requests")
// first delete any follow requests that this account created
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
l.Errorf("error deleting follow requests created by account: %s", err)
}
// now delete any follow requests that target this account
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
l.Errorf("error deleting follow requests targeting account: %s", err)
}
@@ -109,12 +109,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// TODO: federate these if necessary
l.Trace("deleting account follows")
// first delete any follows that this account created
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
l.Errorf("error deleting follows created by account: %s", err)
}
// now delete any follows that target this account
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
l.Errorf("error deleting follows targeting account: %s", err)
}
@@ -129,7 +129,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
for {
// Fetch next block of account statuses from database
- statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false)
+ statuses, err := p.state.DB.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// an actual error has occurred
@@ -149,7 +149,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
l.Tracef("queue client API status delete: %s", status.ID)
// pass the status delete through the client api channel for processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityDelete,
GTSModel: status,
@@ -158,7 +158,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
})
// Look for any boosts of this status in DB
- boosts, err := p.db.GetStatusReblogs(ctx, status)
+ boosts, err := p.state.DB.GetStatusReblogs(ctx, status)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
l.Errorf("error fetching status reblogs for %q: %v", status.ID, err)
continue
@@ -167,7 +167,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
for _, boost := range boosts {
if boost.Account == nil {
// Fetch the relevant account for this status boost
- boostAcc, err := p.db.GetAccountByID(ctx, boost.AccountID)
+ boostAcc, err := p.state.DB.GetAccountByID(ctx, boost.AccountID)
if err != nil {
l.Errorf("error fetching boosted status account for %q: %v", boost.AccountID, err)
continue
@@ -180,7 +180,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
l.Tracef("queue client API boost delete: %s", status.ID)
// pass the boost delete through the client api channel for processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityUndo,
GTSModel: status,
@@ -197,31 +197,31 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// 10. Delete account's notifications
l.Trace("deleting account notifications")
// first notifications created by account
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
l.Errorf("error deleting notifications created by account: %s", err)
}
// now notifications targeting account
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
l.Errorf("error deleting notifications targeting account: %s", err)
}
// 11. Delete account's bookmarks
l.Trace("deleting account bookmarks")
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
l.Errorf("error deleting bookmarks created by account: %s", err)
}
// 12. Delete account's faves
// TODO: federate these if necessary
l.Trace("deleting account faves")
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
l.Errorf("error deleting faves created by account: %s", err)
}
// 13. Delete account's mutes
l.Trace("deleting account mutes")
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
l.Errorf("error deleting status mutes created by account: %s", err)
}
@@ -234,7 +234,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// 16. Delete account's user
if user != nil {
l.Trace("deleting account user")
- if err := p.db.DeleteUserByID(ctx, user.ID); err != nil {
+ if err := p.state.DB.DeleteUserByID(ctx, user.ID); err != nil {
return gtserror.NewErrorInternalError(err)
}
}
@@ -261,7 +261,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
account.Discoverable = &discoverable
account.SuspendedAt = time.Now()
account.SuspensionOrigin = origin
- err := p.db.UpdateAccount(ctx, account)
+ err := p.state.DB.UpdateAccount(ctx, account)
if err != nil {
return gtserror.NewErrorInternalError(err)
}
@@ -281,7 +281,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
if form.DeleteOriginID == account.ID {
// the account owner themself has requested deletion via the API, get their user from the db
- user, err := p.db.GetUserByAccountID(ctx, account.ID)
+ user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)
if err != nil {
return gtserror.NewErrorInternalError(err)
}
@@ -301,7 +301,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
} else {
// the delete has been requested by some other account, grab it;
// if we've reached this point we know it has permission already
- requestingAccount, err := p.db.GetAccountByID(ctx, form.DeleteOriginID)
+ requestingAccount, err := p.state.DB.GetAccountByID(ctx, form.DeleteOriginID)
if err != nil {
return gtserror.NewErrorInternalError(err)
}
@@ -310,7 +310,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
}
// put the delete in the processor queue to handle the rest of it asynchronously
- p.clientWorker.Queue(fromClientAPIMessage)
+ p.state.Workers.EnqueueClientAPI(ctx, fromClientAPIMessage)
return nil
}
diff --git a/internal/processing/account/follow.go b/internal/processing/account/follow.go
index d4d479be7..ac65c39f2 100644
--- a/internal/processing/account/follow.go
+++ b/internal/processing/account/follow.go
@@ -35,14 +35,14 @@ import (
// FollowCreate handles a follow request to an account, either remote or local.
func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't create the request ofc
- if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil {
+ if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
// make sure the target account actually exists in our db
- targetAcct, err := p.db.GetAccountByID(ctx, form.ID)
+ targetAcct, err := p.state.DB.GetAccountByID(ctx, form.ID)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
@@ -51,7 +51,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
// check if a follow exists already
- if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil {
+ if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err))
} else if follows {
// already follows so just return the relationship
@@ -59,7 +59,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
// check if a follow request exists already
- if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil {
+ if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err))
} else if followRequested {
// already follow requested so just return the relationship
@@ -95,13 +95,13 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
// whack it in the database
- if err := p.db.Put(ctx, fr); err != nil {
+ if err := p.state.DB.Put(ctx, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err))
}
// if it's a local account that's not locked we can just straight up accept the follow request
if !*targetAcct.Locked && targetAcct.Domain == "" {
- if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
+ if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err))
}
// return the new relationship
@@ -109,7 +109,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
// otherwise we leave the follow request as it is and we handle the rest of the process asynchronously
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityCreate,
GTSModel: fr,
@@ -124,7 +124,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
// FollowRemove handles the removal of a follow/follow request to an account, either remote or local.
func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't do anything
- blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -133,7 +133,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
}
// make sure the target account actually exists in our db
- targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID)
+ targetAcct, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
@@ -144,12 +144,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
var frChanged bool
var frURI string
fr := >smodel.FollowRequest{}
- if err := p.db.GetWhere(ctx, []db.Where{
+ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, fr); err == nil {
frURI = fr.URI
- if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil {
+ if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err))
}
frChanged = true
@@ -159,12 +159,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
var fChanged bool
var fURI string
f := >smodel.Follow{}
- if err := p.db.GetWhere(ctx, []db.Where{
+ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, f); err == nil {
fURI = f.URI
- if err := p.db.DeleteByID(ctx, f.ID, f); err != nil {
+ if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err))
}
fChanged = true
@@ -172,7 +172,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
// follow request status changed so send the UNDO activity to the channel for async processing
if frChanged {
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo,
GTSModel: >smodel.Follow{
@@ -187,7 +187,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
// follow status changed so send the UNDO activity to the channel for async processing
if fChanged {
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo,
GTSModel: >smodel.Follow{
diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go
index 11de1ddac..2c650254f 100644
--- a/internal/processing/account/get.go
+++ b/internal/processing/account/get.go
@@ -33,7 +33,7 @@ import (
// Get processes the given request for account information.
func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, gtserror.WithCode) {
- targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
+ targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
@@ -46,7 +46,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account
// GetLocalByUsername processes the given request for account information targeting a local account by username.
func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *gtsmodel.Account, username string) (*apimodel.Account, gtserror.WithCode) {
- targetAccount, err := p.db.GetAccountByUsernameDomain(ctx, username, "")
+ targetAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "")
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
@@ -59,7 +59,7 @@ func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *g
// GetCustomCSSForUsername returns custom css for the given local username.
func (p *Processor) GetCustomCSSForUsername(ctx context.Context, username string) (string, gtserror.WithCode) {
- customCSS, err := p.db.GetAccountCustomCSSByUsername(ctx, username)
+ customCSS, err := p.state.DB.GetAccountCustomCSSByUsername(ctx, username)
if err != nil {
if err == db.ErrNoEntries {
return "", gtserror.NewErrorNotFound(errors.New("account not found"))
@@ -74,7 +74,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco
var blocked bool
var err error
if requestingAccount != nil {
- blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true)
+ blocked, err = p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking account block: %s", err))
}
diff --git a/internal/processing/account/relationships.go b/internal/processing/account/relationships.go
index cb2789829..f60216f95 100644
--- a/internal/processing/account/relationships.go
+++ b/internal/processing/account/relationships.go
@@ -31,14 +31,14 @@ import (
// FollowersGet fetches a list of the target account's followers.
func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
accounts := []apimodel.Account{}
- follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false)
+ follows, err := p.state.DB.GetAccountFollowedBy(ctx, targetAccountID, false)
if err != nil {
if err == db.ErrNoEntries {
return accounts, nil
@@ -47,7 +47,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
}
for _, f := range follows {
- blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -56,7 +56,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
}
if f.Account == nil {
- a, err := p.db.GetAccountByID(ctx, f.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, f.AccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
@@ -77,14 +77,14 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
// FollowingGet fetches a list of the accounts that target account is following.
func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
accounts := []apimodel.Account{}
- follows, err := p.db.GetAccountFollows(ctx, targetAccountID)
+ follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
return accounts, nil
@@ -93,7 +93,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
}
for _, f := range follows {
- blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -102,7 +102,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
}
if f.TargetAccount == nil {
- a, err := p.db.GetAccountByID(ctx, f.TargetAccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, f.TargetAccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
@@ -127,7 +127,7 @@ func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsm
return nil, gtserror.NewErrorForbidden(errors.New("not authed"))
}
- gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
+ gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err))
}
diff --git a/internal/processing/account/rss.go b/internal/processing/account/rss.go
index 22065cf8e..61fcc1c51 100644
--- a/internal/processing/account/rss.go
+++ b/internal/processing/account/rss.go
@@ -34,7 +34,7 @@ const rssFeedLength = 20
// GetRSSFeedForUsername returns RSS feed for the given local username.
func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) (func() (string, gtserror.WithCode), time.Time, gtserror.WithCode) {
- account, err := p.db.GetAccountByUsernameDomain(ctx, username, "")
+ account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "")
if err != nil {
if err == db.ErrNoEntries {
return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account not found"))
@@ -46,13 +46,13 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)
return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account RSS feed not enabled"))
}
- lastModified, err := p.db.GetAccountLastPosted(ctx, account.ID, true)
+ lastModified, err := p.state.DB.GetAccountLastPosted(ctx, account.ID, true)
if err != nil {
return nil, time.Time{}, gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err))
}
return func() (string, gtserror.WithCode) {
- statuses, err := p.db.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "")
+ statuses, err := p.state.DB.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "")
if err != nil && err != db.ErrNoEntries {
return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err))
}
@@ -65,7 +65,7 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)
var image *feeds.Image
if account.AvatarMediaAttachmentID != "" {
if account.AvatarMediaAttachment == nil {
- avatar, err := p.db.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
+ avatar, err := p.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
if err != nil {
return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error fetching avatar attachment: %s", err))
}
diff --git a/internal/processing/account/statuses.go b/internal/processing/account/statuses.go
index 7ff6de2ff..9961dbdbe 100644
--- a/internal/processing/account/statuses.go
+++ b/internal/processing/account/statuses.go
@@ -33,7 +33,7 @@ import (
// the account given in authed.
func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) {
if requestingAccount != nil {
- if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
@@ -46,10 +46,10 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel
)
if pinned {
// Get *ONLY* pinned statuses.
- statuses, err = p.db.GetAccountPinnedStatuses(ctx, targetAccountID)
+ statuses, err = p.state.DB.GetAccountPinnedStatuses(ctx, targetAccountID)
} else {
// Get account statuses which *may* include pinned ones.
- statuses, err = p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly)
+ statuses, err = p.state.DB.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly)
}
if err != nil {
if err == db.ErrNoEntries {
@@ -120,7 +120,7 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel
// WebStatusesGet fetches a number of statuses (in descending order) from the given account. It selects only
// statuses which are suitable for showing on the public web profile of an account.
func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, maxID string) (*apimodel.PageableResponse, gtserror.WithCode) {
- acct, err := p.db.GetAccountByID(ctx, targetAccountID)
+ acct, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
err := fmt.Errorf("account %s not found in the db, not getting web statuses for it", targetAccountID)
@@ -134,7 +134,7 @@ func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string,
return nil, gtserror.NewErrorNotFound(err)
}
- statuses, err := p.db.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID)
+ statuses, err := p.state.DB.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID)
if err != nil {
if err == db.ErrNoEntries {
return util.EmptyPageableResponse(), nil
diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go
index cffbbb0c5..537857cee 100644
--- a/internal/processing/account/update.go
+++ b/internal/processing/account/update.go
@@ -165,12 +165,12 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, form
account.EnableRSS = form.EnableRSS
}
- err := p.db.UpdateAccount(ctx, account)
+ err := p.state.DB.UpdateAccount(ctx, account)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err))
}
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate,
GTSModel: account,
diff --git a/internal/processing/admin/account.go b/internal/processing/admin/account.go
index d23d1fbfe..ba4c5d4eb 100644
--- a/internal/processing/admin/account.go
+++ b/internal/processing/admin/account.go
@@ -31,7 +31,7 @@ import (
)
func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account, form *apimodel.AdminAccountActionRequest) gtserror.WithCode {
- targetAccount, err := p.db.GetAccountByID(ctx, form.TargetAccountID)
+ targetAccount, err := p.state.DB.GetAccountByID(ctx, form.TargetAccountID)
if err != nil {
return gtserror.NewErrorInternalError(err)
}
@@ -47,7 +47,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account
case string(gtsmodel.AdminActionSuspend):
adminAction.Type = gtsmodel.AdminActionSuspend
// pass the account delete through the client api channel for processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActorPerson,
APActivityType: ap.ActivityDelete,
OriginAccount: account,
@@ -57,7 +57,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account
return gtserror.NewErrorBadRequest(fmt.Errorf("admin action type %s is not supported for this endpoint", form.Type))
}
- if err := p.db.Put(ctx, adminAction); err != nil {
+ if err := p.state.DB.Put(ctx, adminAction); err != nil {
return gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go
index 54827b8fd..ba09969dc 100644
--- a/internal/processing/admin/admin.go
+++ b/internal/processing/admin/admin.go
@@ -19,32 +19,25 @@
package admin
import (
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
)
type Processor struct {
+ state *state.State
tc typeutils.TypeConverter
mediaManager media.Manager
transportController transport.Controller
- storage *storage.Driver
- clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
- db db.DB
}
// New returns a new admin processor.
-func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
+func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor {
return Processor{
+ state: state,
tc: tc,
mediaManager: mediaManager,
transportController: transportController,
- storage: storage,
- clientWorker: clientWorker,
- db: db,
}
}
diff --git a/internal/processing/admin/domainblock.go b/internal/processing/admin/domainblock.go
index 415ac610f..dd22f72e6 100644
--- a/internal/processing/admin/domainblock.go
+++ b/internal/processing/admin/domainblock.go
@@ -28,7 +28,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc
domain = strings.ToLower(domain)
// first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work
- block, err := p.db.GetDomainBlock(ctx, domain)
+ block, err := p.state.DB.GetDomainBlock(ctx, domain)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// something went wrong in the DB
@@ -47,7 +47,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc
}
// Insert the new block into the database
- if err := p.db.CreateDomainBlock(ctx, newBlock); err != nil {
+ if err := p.state.DB.CreateDomainBlock(ctx, newBlock); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err))
}
@@ -80,7 +80,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
// if we have an instance entry for this domain, update it with the new block ID and clear all fields
instance := >smodel.Instance{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil {
+ if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil {
updatingColumns := []string{
"title",
"updated_at",
@@ -105,15 +105,15 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
instance.ContactAccountUsername = ""
instance.ContactAccountID = ""
instance.Version = ""
- if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
+ if err := p.state.DB.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
}
l.Debug("domainBlockProcessSideEffects: instance entry updated")
}
// if we have an instance account for this instance, delete it
- if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
- if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil {
+ if instanceAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
+ if err := p.state.DB.DeleteAccount(ctx, instanceAccount.ID); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
}
}
@@ -125,7 +125,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
selectAccountsLoop:
for {
- accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
+ accounts, err := p.state.DB.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
if err != nil {
if err == db.ErrNoEntries {
// no accounts left for this instance so we're done
@@ -141,7 +141,7 @@ selectAccountsLoop:
l.Debugf("putting delete for account %s in the clientAPI channel", a.Username)
// pass the account delete through the client api channel for processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActorPerson,
APActivityType: ap.ActivityDelete,
GTSModel: block,
@@ -195,7 +195,7 @@ func (p *Processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Ac
func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
domainBlocks := []*gtsmodel.DomainBlock{}
- if err := p.db.GetAll(ctx, &domainBlocks); err != nil {
+ if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
@@ -219,7 +219,7 @@ func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Accou
func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := >smodel.DomainBlock{}
- if err := p.db.GetByID(ctx, id, domainBlock); err != nil {
+ if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
@@ -240,7 +240,7 @@ func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Accoun
func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := >smodel.DomainBlock{}
- if err := p.db.GetByID(ctx, id, domainBlock); err != nil {
+ if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
@@ -256,13 +256,13 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
}
// Delete the domain block
- if err := p.db.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
+ if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// remove the domain block reference from the instance, if we have an entry for it
i := >smodel.Instance{}
- if err := p.db.GetWhere(ctx, []db.Where{
+ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "domain", Value: domainBlock.Domain},
{Key: "domain_block_id", Value: id},
}, i); err == nil {
@@ -270,21 +270,21 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
i.SuspendedAt = time.Time{}
i.DomainBlockID = ""
i.UpdatedAt = time.Now()
- if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
+ if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
}
}
// unsuspend all accounts whose suspension origin was this domain block
// 1. remove the 'suspended_at' entry from their accounts
- if err := p.db.UpdateWhere(ctx, []db.Where{
+ if err := p.state.DB.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID},
}, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err))
}
// 2. remove the 'suspension_origin' entry from their accounts
- if err := p.db.UpdateWhere(ctx, []db.Where{
+ if err := p.state.DB.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID},
}, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err))
diff --git a/internal/processing/admin/emoji.go b/internal/processing/admin/emoji.go
index 391d18525..3eacbf888 100644
--- a/internal/processing/admin/emoji.go
+++ b/internal/processing/admin/emoji.go
@@ -42,7 +42,7 @@ func (p *Processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account,
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
}
- maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "")
+ maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "")
if maybeExisting != nil {
return nil, gtserror.NewErrorConflict(fmt.Errorf("emoji with shortcode %s already exists", form.Shortcode), fmt.Sprintf("emoji with shortcode %s already exists", form.Shortcode))
}
@@ -110,7 +110,7 @@ func (p *Processor) EmojisGet(
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
}
- emojis, err := p.db.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit)
+ emojis, err := p.state.DB.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err := fmt.Errorf("EmojisGet: db error: %s", err)
return nil, gtserror.NewErrorInternalError(err)
@@ -176,7 +176,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
}
- emoji, err := p.db.GetEmojiByID(ctx, id)
+ emoji, err := p.state.DB.GetEmojiByID(ctx, id)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("EmojiGet: no emoji with id %s found in the db", id)
@@ -197,7 +197,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use
// EmojiDelete deletes one emoji from the database, with the given id.
func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.AdminEmoji, gtserror.WithCode) {
- emoji, err := p.db.GetEmojiByID(ctx, id)
+ emoji, err := p.state.DB.GetEmojiByID(ctx, id)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("EmojiDelete: no emoji with id %s found in the db", id)
@@ -218,7 +218,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin
return nil, gtserror.NewErrorInternalError(err)
}
- if err := p.db.DeleteEmojiByID(ctx, id); err != nil {
+ if err := p.state.DB.DeleteEmojiByID(ctx, id); err != nil {
err := fmt.Errorf("EmojiDelete: db error: %s", err)
return nil, gtserror.NewErrorInternalError(err)
}
@@ -228,7 +228,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin
// EmojiUpdate updates one emoji with the given id, using the provided form parameters.
func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.EmojiUpdateRequest) (*apimodel.AdminEmoji, gtserror.WithCode) {
- emoji, err := p.db.GetEmojiByID(ctx, id)
+ emoji, err := p.state.DB.GetEmojiByID(ctx, id)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("EmojiUpdate: no emoji with id %s found in the db", id)
@@ -253,7 +253,7 @@ func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.E
// EmojiCategoriesGet returns all custom emoji categories that exist on this instance.
func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCategory, gtserror.WithCode) {
- categories, err := p.db.GetEmojiCategories(ctx)
+ categories, err := p.state.DB.GetEmojiCategories(ctx)
if err != nil {
err := fmt.Errorf("EmojiCategoriesGet: db error: %s", err)
return nil, gtserror.NewErrorInternalError(err)
@@ -277,7 +277,7 @@ func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCa
*/
func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) {
- category, err := p.db.GetEmojiCategoryByName(ctx, name)
+ category, err := p.state.DB.GetEmojiCategoryByName(ctx, name)
if err == nil {
return category, nil
}
@@ -299,7 +299,7 @@ func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (
Name: name,
}
- if err := p.db.PutEmojiCategory(ctx, category); err != nil {
+ if err := p.state.DB.PutEmojiCategory(ctx, category); err != nil {
err = fmt.Errorf("GetOrCreateEmojiCategory: error putting new emoji category in the database: %s", err)
return nil, err
}
@@ -319,7 +319,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji,
return nil, gtserror.NewErrorBadRequest(err, err.Error())
}
- maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, *shortcode, "")
+ maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, *shortcode, "")
if maybeExisting != nil {
err := fmt.Errorf("emojiUpdateCopy: emoji %s could not be copied, emoji with shortcode %s already exists on this instance", emoji.ID, *shortcode)
return nil, gtserror.NewErrorConflict(err, err.Error())
@@ -339,7 +339,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji,
newEmojiURI := uris.GenerateURIForEmoji(newEmojiID)
data := func(ctx context.Context) (reader io.ReadCloser, fileSize int64, err error) {
- rc, err := p.storage.GetStream(ctx, emoji.ImagePath)
+ rc, err := p.state.Storage.GetStream(ctx, emoji.ImagePath)
return rc, int64(emoji.ImageFileSize), err
}
@@ -386,7 +386,7 @@ func (p *Processor) emojiUpdateDisable(ctx context.Context, emoji *gtsmodel.Emoj
emojiDisabled := true
emoji.Disabled = &emojiDisabled
- updatedEmoji, err := p.db.UpdateEmoji(ctx, emoji, "updated_at", "disabled")
+ updatedEmoji, err := p.state.DB.UpdateEmoji(ctx, emoji, "updated_at", "disabled")
if err != nil {
err = fmt.Errorf("emojiUpdateDisable: error updating emoji %s: %s", emoji.ID, err)
return nil, gtserror.NewErrorInternalError(err)
@@ -443,7 +443,7 @@ func (p *Processor) emojiUpdateModify(ctx context.Context, emoji *gtsmodel.Emoji
}
var err error
- updatedEmoji, err = p.db.UpdateEmoji(ctx, emoji, columns...)
+ updatedEmoji, err = p.state.DB.UpdateEmoji(ctx, emoji, columns...)
if err != nil {
err = fmt.Errorf("emojiUpdateModify: error updating emoji %s: %s", emoji.ID, err)
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/admin/report.go b/internal/processing/admin/report.go
index 3a6028bca..bed97e204 100644
--- a/internal/processing/admin/report.go
+++ b/internal/processing/admin/report.go
@@ -43,7 +43,7 @@ func (p *Processor) ReportsGet(
minID string,
limit int,
) (*apimodel.PageableResponse, gtserror.WithCode) {
- reports, err := p.db.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit)
+ reports, err := p.state.DB.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit)
if err != nil {
if err == db.ErrNoEntries {
return util.EmptyPageableResponse(), nil
@@ -95,7 +95,7 @@ func (p *Processor) ReportsGet(
// ReportGet returns one report, with the given ID.
func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.AdminReport, gtserror.WithCode) {
- report, err := p.db.GetReportByID(ctx, id)
+ report, err := p.state.DB.GetReportByID(ctx, id)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err)
@@ -113,7 +113,7 @@ func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id
// ReportResolve marks a report with the given id as resolved, and stores the provided actionTakenComment (if not null).
func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account, id string, actionTakenComment *string) (*apimodel.AdminReport, gtserror.WithCode) {
- report, err := p.db.GetReportByID(ctx, id)
+ report, err := p.state.DB.GetReportByID(ctx, id)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err)
@@ -134,7 +134,7 @@ func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "action_taken")
}
- updatedReport, err := p.db.UpdateReport(ctx, report, columns...)
+ updatedReport, err := p.state.DB.UpdateReport(ctx, report, columns...)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/app.go b/internal/processing/app.go
index f2a938b22..e4cda5a43 100644
--- a/internal/processing/app.go
+++ b/internal/processing/app.go
@@ -62,7 +62,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
}
// chuck it in the db
- if err := p.db.Put(ctx, app); err != nil {
+ if err := p.state.DB.Put(ctx, app); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -76,7 +76,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
}
// chuck it in the db
- if err := p.db.Put(ctx, oc); err != nil {
+ if err := p.state.DB.Put(ctx, oc); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go
index 6dd9c3de9..754954f02 100644
--- a/internal/processing/blocks.go
+++ b/internal/processing/blocks.go
@@ -31,7 +31,7 @@ import (
)
func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
- accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)
+ accounts, nextMaxID, prevMinID, err := p.state.DB.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)
if err != nil {
if err == db.ErrNoEntries {
// there are just no entries
diff --git a/internal/processing/fedi/collections.go b/internal/processing/fedi/collections.go
index 78a65bebe..627511c3b 100644
--- a/internal/processing/fedi/collections.go
+++ b/internal/processing/fedi/collections.go
@@ -84,8 +84,8 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag
// scenario 2 -- get the requested page
// limit pages to 30 entries per page
- publicStatuses, err := p.db.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true)
- if err != nil && err != db.ErrNoEntries {
+ publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -161,7 +161,7 @@ func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUsername
return nil, errWithCode
}
- statuses, err := p.db.GetAccountPinnedStatuses(ctx, requestedAccount.ID)
+ statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, requestedAccount.ID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/fedi/common.go b/internal/processing/fedi/common.go
index 37c604ded..a2c7f9b37 100644
--- a/internal/processing/fedi/common.go
+++ b/internal/processing/fedi/common.go
@@ -29,7 +29,7 @@ import (
)
func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) {
- requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "")
+ requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
if err != nil {
errWithCode = gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
return
@@ -46,7 +46,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string)
return
}
- blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
errWithCode = gtserror.NewErrorInternalError(err)
return
diff --git a/internal/processing/fedi/emoji.go b/internal/processing/fedi/emoji.go
index 0b1dd3440..b2618ca13 100644
--- a/internal/processing/fedi/emoji.go
+++ b/internal/processing/fedi/emoji.go
@@ -32,7 +32,7 @@ func (p *Processor) EmojiGet(ctx context.Context, requestedEmojiID string) (inte
return nil, errWithCode
}
- requestedEmoji, err := p.db.GetEmojiByID(ctx, requestedEmojiID)
+ requestedEmoji, err := p.state.DB.GetEmojiByID(ctx, requestedEmojiID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting emoji with id %s: %s", requestedEmojiID, err))
}
diff --git a/internal/processing/fedi/fedi.go b/internal/processing/fedi/fedi.go
index e72d037f5..c8f78c5a6 100644
--- a/internal/processing/fedi/fedi.go
+++ b/internal/processing/fedi/fedi.go
@@ -19,25 +19,25 @@
package fedi
import (
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
type Processor struct {
- db db.DB
+ state *state.State
federator federation.Federator
tc typeutils.TypeConverter
filter visibility.Filter
}
// New returns a new fedi processor.
-func New(db db.DB, tc typeutils.TypeConverter, federator federation.Federator) Processor {
+func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator) Processor {
return Processor{
- db: db,
+ state: state,
federator: federator,
tc: tc,
- filter: visibility.NewFilter(db),
+ filter: visibility.NewFilter(state.DB),
}
}
diff --git a/internal/processing/fedi/status.go b/internal/processing/fedi/status.go
index fbadcb290..60ebb3c84 100644
--- a/internal/processing/fedi/status.go
+++ b/internal/processing/fedi/status.go
@@ -36,7 +36,7 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req
return nil, errWithCode
}
- status, err := p.db.GetStatusByID(ctx, requestedStatusID)
+ status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
@@ -74,7 +74,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
return nil, errWithCode
}
- status, err := p.db.GetStatusByID(ctx, requestedStatusID)
+ status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
@@ -125,7 +125,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
default:
// scenario 3
// get immediate children
- replies, err := p.db.GetStatusChildren(ctx, status, true, minID)
+ replies, err := p.state.DB.GetStatusChildren(ctx, status, true, minID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/fedi/user.go b/internal/processing/fedi/user.go
index 899d063d1..35e756e57 100644
--- a/internal/processing/fedi/user.go
+++ b/internal/processing/fedi/user.go
@@ -34,7 +34,7 @@ import (
// before returning a JSON serializable interface to the caller.
func (p *Processor) UserGet(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// Get the instance-local account the request is referring to.
- requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "")
+ requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -63,7 +63,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque
return nil, gtserror.NewErrorUnauthorized(err)
}
- blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/fedi/wellknown.go b/internal/processing/fedi/wellknown.go
index 75ed34ec2..6f113ac5d 100644
--- a/internal/processing/fedi/wellknown.go
+++ b/internal/processing/fedi/wellknown.go
@@ -64,12 +64,12 @@ func (p *Processor) NodeInfoRelGet(ctx context.Context) (*apimodel.WellKnownResp
func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserror.WithCode) {
host := config.GetHost()
- userCount, err := p.db.CountInstanceUsers(ctx, host)
+ userCount, err := p.state.DB.CountInstanceUsers(ctx, host)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- postCount, err := p.db.CountInstanceStatuses(ctx, host)
+ postCount, err := p.state.DB.CountInstanceStatuses(ctx, host)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -99,7 +99,7 @@ func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserr
// WebfingerGet handles the GET for a webfinger resource. Most commonly, it will be used for returning account lookups.
func (p *Processor) WebfingerGet(ctx context.Context, requestedUsername string) (*apimodel.WellKnownResponse, gtserror.WithCode) {
// Get the local account the request is referring to.
- requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "")
+ requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go
index 1f1b7f3c2..9bd13cc0b 100644
--- a/internal/processing/followrequest.go
+++ b/internal/processing/followrequest.go
@@ -30,7 +30,7 @@ import (
)
func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
- frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID)
+ frs, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID)
if err != nil {
if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(err)
@@ -40,7 +40,7 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
accts := []apimodel.Account{}
for _, fr := range frs {
if fr.Account == nil {
- frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID)
+ frAcct, err := p.state.DB.GetAccountByID(ctx, fr.AccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -57,13 +57,13 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
}
func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
- follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
+ follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
if follow.Account == nil {
- followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID)
+ followAccount, err := p.state.DB.GetAccountByID(ctx, follow.AccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -71,14 +71,14 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
}
if follow.TargetAccount == nil {
- followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID)
+ followTargetAccount, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
follow.TargetAccount = followTargetAccount
}
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept,
GTSModel: follow,
@@ -86,7 +86,7 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
TargetAccount: follow.TargetAccount,
})
- gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID)
+ gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -100,13 +100,13 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
}
func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
- followRequest, err := p.db.RejectFollowRequest(ctx, accountID, auth.Account.ID)
+ followRequest, err := p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
if followRequest.Account == nil {
- a, err := p.db.GetAccountByID(ctx, followRequest.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -114,14 +114,14 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a
}
if followRequest.TargetAccount == nil {
- a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
followRequest.TargetAccount = a
}
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityReject,
GTSModel: followRequest,
@@ -129,7 +129,7 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a
TargetAccount: followRequest.TargetAccount,
})
- gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID)
+ gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go
index 701f425f6..209a27105 100644
--- a/internal/processing/fromclientapi.go
+++ b/internal/processing/fromclientapi.go
@@ -143,7 +143,7 @@ func (p *Processor) processCreateAccountFromClientAPI(ctx context.Context, clien
}
// get the user this account belongs to
- user, err := p.db.GetUserByAccountID(ctx, account.ID)
+ user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)
if err != nil {
return err
}
@@ -293,7 +293,7 @@ func (p *Processor) processUndoAnnounceFromClientAPI(ctx context.Context, client
return errors.New("undo was not parseable as *gtsmodel.Status")
}
- if err := p.db.DeleteStatusByID(ctx, boost.ID); err != nil {
+ if err := p.state.DB.DeleteStatusByID(ctx, boost.ID); err != nil {
return err
}
@@ -422,7 +422,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status)
}
if status.Account == nil {
- statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
+ statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
}
@@ -455,7 +455,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status)
func (p *Processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
- statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
+ statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err)
}
@@ -642,7 +642,7 @@ func (p *Processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Stat
func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow) error {
if follow.Account == nil {
- a, err := p.db.GetAccountByID(ctx, follow.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, follow.AccountID)
if err != nil {
return err
}
@@ -651,7 +651,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts
originAccount := follow.Account
if follow.TargetAccount == nil {
- a, err := p.db.GetAccountByID(ctx, follow.TargetAccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil {
return err
}
@@ -715,7 +715,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts
func (p *Processor) federateRejectFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {
if followRequest.Account == nil {
- a, err := p.db.GetAccountByID(ctx, followRequest.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
if err != nil {
return err
}
@@ -724,7 +724,7 @@ func (p *Processor) federateRejectFollowRequest(ctx context.Context, followReque
originAccount := followRequest.Account
if followRequest.TargetAccount == nil {
- a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil {
return err
}
@@ -844,7 +844,7 @@ func (p *Processor) federateAccountUpdate(ctx context.Context, updatedAccount *g
func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
- blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
+ blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID)
if err != nil {
return fmt.Errorf("federateBlock: error getting block account from database: %s", err)
}
@@ -852,7 +852,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er
}
if block.TargetAccount == nil {
- blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
+ blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID)
if err != nil {
return fmt.Errorf("federateBlock: error getting block target account from database: %s", err)
}
@@ -880,7 +880,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er
func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
- blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
+ blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID)
if err != nil {
return fmt.Errorf("federateUnblock: error getting block account from database: %s", err)
}
@@ -888,7 +888,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block)
}
if block.TargetAccount == nil {
- blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
+ blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID)
if err != nil {
return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err)
}
@@ -934,7 +934,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block)
func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report) error {
if report.TargetAccount == nil {
- reportTargetAccount, err := p.db.GetAccountByID(ctx, report.TargetAccountID)
+ reportTargetAccount, err := p.state.DB.GetAccountByID(ctx, report.TargetAccountID)
if err != nil {
return fmt.Errorf("federateReport: error getting report target account from database: %w", err)
}
@@ -942,7 +942,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report)
}
if len(report.StatusIDs) > 0 && len(report.Statuses) == 0 {
- statuses, err := p.db.GetStatuses(ctx, report.StatusIDs)
+ statuses, err := p.state.DB.GetStatuses(ctx, report.StatusIDs)
if err != nil {
return fmt.Errorf("federateReport: error getting report statuses from database: %w", err)
}
@@ -966,7 +966,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report)
// deliver the flag using the outbox of the
// instance account to anonymize the report
- instanceAccount, err := p.db.GetInstanceAccount(ctx, "")
+ instanceAccount, err := p.state.DB.GetInstanceAccount(ctx, "")
if err != nil {
return fmt.Errorf("federateReport: error getting instance account: %w", err)
}
diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go
index 3e4c62c6c..f9e732732 100644
--- a/internal/processing/fromcommon.go
+++ b/internal/processing/fromcommon.go
@@ -38,7 +38,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
if status.Mentions == nil {
// there are mentions but they're not fully populated on the status yet so do this
- menchies, err := p.db.GetMentions(ctx, status.MentionIDs)
+ menchies, err := p.state.DB.GetMentions(ctx, status.MentionIDs)
if err != nil {
return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err)
}
@@ -49,7 +49,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
for _, m := range status.Mentions {
// make sure this is a local account, otherwise we don't need to create a notification for it
if m.TargetAccount == nil {
- a, err := p.db.GetAccountByID(ctx, m.TargetAccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, m.TargetAccountID)
if err != nil {
// we don't have the account or there's been an error
return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err)
@@ -62,7 +62,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
}
// make sure a notif doesn't already exist for this mention
- if err := p.db.GetWhere(ctx, []db.Where{
+ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationMention},
{Key: "target_account_id", Value: m.TargetAccountID},
{Key: "origin_account_id", Value: m.OriginAccountID},
@@ -87,7 +87,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
Status: status,
}
- if err := p.db.Put(ctx, notif); err != nil {
+ if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyStatus: error putting notification in database: %s", err)
}
@@ -108,7 +108,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {
// make sure we have the target account pinned on the follow request
if followRequest.TargetAccount == nil {
- a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil {
return err
}
@@ -129,7 +129,7 @@ func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsm
OriginAccountID: followRequest.AccountID,
}
- if err := p.db.Put(ctx, notif); err != nil {
+ if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err)
}
@@ -153,7 +153,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
}
// first remove the follow request notification
- if err := p.db.DeleteWhere(ctx, []db.Where{
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationFollowRequest},
{Key: "target_account_id", Value: follow.TargetAccountID},
{Key: "origin_account_id", Value: follow.AccountID},
@@ -170,7 +170,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
OriginAccountID: follow.AccountID,
OriginAccount: follow.Account,
}
- if err := p.db.Put(ctx, notif); err != nil {
+ if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFollow: error putting notification in database: %s", err)
}
@@ -194,7 +194,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e
}
if fave.TargetAccount == nil {
- a, err := p.db.GetAccountByID(ctx, fave.TargetAccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, fave.TargetAccountID)
if err != nil {
return err
}
@@ -218,7 +218,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e
Status: fave.Status,
}
- if err := p.db.Put(ctx, notif); err != nil {
+ if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFave: error putting notification in database: %s", err)
}
@@ -242,7 +242,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
}
if status.BoostOf == nil {
- boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID)
+ boostedStatus, err := p.state.DB.GetStatusByID(ctx, status.BoostOfID)
if err != nil {
return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err)
}
@@ -250,7 +250,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
}
if status.BoostOfAccount == nil {
- boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID)
+ boostedAcct, err := p.state.DB.GetAccountByID(ctx, status.BoostOfAccountID)
if err != nil {
return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err)
}
@@ -269,7 +269,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
}
// make sure a notif doesn't already exist for this announce
- err := p.db.GetWhere(ctx, []db.Where{
+ err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationReblog},
{Key: "target_account_id", Value: status.BoostOfAccountID},
{Key: "origin_account_id", Value: status.AccountID},
@@ -292,7 +292,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
Status: status,
}
- if err := p.db.Put(ctx, notif); err != nil {
+ if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err)
}
@@ -314,7 +314,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error {
// make sure the author account is pinned onto the status
if status.Account == nil {
- a, err := p.db.GetAccountByID(ctx, status.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err)
}
@@ -322,7 +322,7 @@ func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status)
}
// get local followers of the account that posted the status
- follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true)
+ follows, err := p.state.DB.GetAccountFollowedBy(ctx, status.AccountID, true)
if err != nil {
return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err)
}
@@ -374,7 +374,7 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, status *gtsmod
defer wg.Done()
// get the timeline owner account
- timelineAccount, err := p.db.GetAccountByID(ctx, accountID)
+ timelineAccount, err := p.state.DB.GetAccountByID(ctx, accountID)
if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err)
return
@@ -446,28 +446,28 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
// delete all mention entries generated by this status
for _, m := range statusToDelete.MentionIDs {
- if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil {
+ if err := p.state.DB.DeleteByID(ctx, m, >smodel.Mention{}); err != nil {
return err
}
}
// delete all notification entries generated by this status
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
return err
}
// delete all bookmarks that point to this status
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
return err
}
// delete all boosts for this status + remove them from timelines
- if boosts, err := p.db.GetStatusReblogs(ctx, statusToDelete); err == nil {
+ if boosts, err := p.state.DB.GetStatusReblogs(ctx, statusToDelete); err == nil {
for _, b := range boosts {
if err := p.deleteStatusFromTimelines(ctx, b); err != nil {
return err
}
- if err := p.db.DeleteStatusByID(ctx, b.ID); err != nil {
+ if err := p.state.DB.DeleteStatusByID(ctx, b.ID); err != nil {
return err
}
}
@@ -479,7 +479,7 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
}
// delete the status itself
- if err := p.db.DeleteStatusByID(ctx, statusToDelete.ID); err != nil {
+ if err := p.state.DB.DeleteStatusByID(ctx, statusToDelete.ID); err != nil {
return err
}
diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go
index eea3c529d..afddedf93 100644
--- a/internal/processing/fromfederator.go
+++ b/internal/processing/fromfederator.go
@@ -139,7 +139,7 @@ func (p *Processor) processCreateStatusFromFederator(ctx context.Context, federa
// make sure the account is pinned
if status.Account == nil {
- a, err := p.db.GetAccountByID(ctx, status.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil {
return err
}
@@ -185,7 +185,7 @@ func (p *Processor) processCreateFaveFromFederator(ctx context.Context, federato
// make sure the account is pinned
if incomingFave.Account == nil {
- a, err := p.db.GetAccountByID(ctx, incomingFave.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, incomingFave.AccountID)
if err != nil {
return err
}
@@ -227,7 +227,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
// make sure the account is pinned
if followRequest.Account == nil {
- a, err := p.db.GetAccountByID(ctx, followRequest.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
if err != nil {
return err
}
@@ -254,7 +254,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
}
if followRequest.TargetAccount == nil {
- a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil {
return err
}
@@ -267,7 +267,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
}
// if the target account isn't locked, we should already accept the follow and notify about the new follower instead
- follow, err := p.db.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID)
+ follow, err := p.state.DB.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID)
if err != nil {
return err
}
@@ -288,7 +288,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede
// make sure the account is pinned
if incomingAnnounce.Account == nil {
- a, err := p.db.GetAccountByID(ctx, incomingAnnounce.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, incomingAnnounce.AccountID)
if err != nil {
return err
}
@@ -324,7 +324,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede
}
incomingAnnounce.ID = incomingAnnounceID
- if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil {
+ if err := p.state.DB.PutStatus(ctx, incomingAnnounce); err != nil {
return fmt.Errorf("error adding dereferenced announce to the db: %s", err)
}
diff --git a/internal/processing/fromfederator_test.go b/internal/processing/fromfederator_test.go
index 6913b22af..d8f8ad6e1 100644
--- a/internal/processing/fromfederator_test.go
+++ b/internal/processing/fromfederator_test.go
@@ -344,7 +344,6 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {
suite.NoError(err)
// now they are mufos!
-
err = suite.processor.ProcessFromFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityDelete,
diff --git a/internal/processing/instance.go b/internal/processing/instance.go
index c3dc4dcea..3ca807af3 100644
--- a/internal/processing/instance.go
+++ b/internal/processing/instance.go
@@ -35,7 +35,7 @@ import (
func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) {
i := >smodel.Instance{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil {
+ if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil {
return nil, err
}
return i, nil
@@ -73,7 +73,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
domains := []*apimodel.Domain{}
if includeOpen {
- instances, err := p.db.GetInstancePeers(ctx, false)
+ instances, err := p.state.DB.GetInstancePeers(ctx, false)
if err != nil && err != db.ErrNoEntries {
err = fmt.Errorf("error selecting instance peers: %s", err)
return nil, gtserror.NewErrorInternalError(err)
@@ -87,7 +87,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
if includeSuspended {
domainBlocks := []*gtsmodel.DomainBlock{}
- if err := p.db.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries {
+ if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -124,12 +124,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
// fetch the instance entry from the db for processing
i := >smodel.Instance{}
host := config.GetHost()
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil {
+ if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err))
}
// fetch the instance account from the db for processing
- ia, err := p.db.GetInstanceAccount(ctx, "")
+ ia, err := p.state.DB.GetInstanceAccount(ctx, "")
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", host, err))
}
@@ -148,12 +148,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
// validate & update site contact account if it's set on the form
if form.ContactUsername != nil {
// make sure the account with the given username exists in the db
- contactAccount, err := p.db.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "")
+ contactAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "")
if err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))
}
// make sure it has a user associated with it
- contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID)
+ contactUser, err := p.state.DB.GetUserByAccountID(ctx, contactAccount.ID)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))
}
@@ -233,7 +233,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
} else if form.AvatarDescription != nil && ia.AvatarMediaAttachment != nil {
// process just the description for the existing avatar
ia.AvatarMediaAttachment.Description = *form.AvatarDescription
- if err := p.db.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil {
+ if err := p.state.DB.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance avatar description: %s", err))
}
}
@@ -252,13 +252,13 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
if updateInstanceAccount {
// if either avatar or header is updated, we need
// to update the instance account that stores them
- if err := p.db.UpdateAccount(ctx, ia); err != nil {
+ if err := p.state.DB.UpdateAccount(ctx, ia); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err))
}
}
if len(updatingColumns) != 0 {
- if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
+ if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err))
}
}
diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go
index 6507fcae4..02bd6cd0d 100644
--- a/internal/processing/media/delete.go
+++ b/internal/processing/media/delete.go
@@ -13,7 +13,7 @@ import (
// Delete deletes the media attachment with the given ID, including all files pertaining to that attachment.
func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode {
- attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil {
if err == db.ErrNoEntries {
// attachment already gone
@@ -27,20 +27,20 @@ func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserr
// delete the thumbnail from storage
if attachment.Thumbnail.Path != "" {
- if err := p.storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
+ if err := p.state.Storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err))
}
}
// delete the file from storage
if attachment.File.Path != "" {
- if err := p.storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
+ if err := p.state.Storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err))
}
}
// delete the attachment
- if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) {
+ if err := p.state.DB.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) {
errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
}
diff --git a/internal/processing/media/getemoji.go b/internal/processing/media/getemoji.go
index 4c0ce9930..fba059f60 100644
--- a/internal/processing/media/getemoji.go
+++ b/internal/processing/media/getemoji.go
@@ -31,7 +31,7 @@ import (
// GetCustomEmojis returns a list of all useable local custom emojis stored on this instance.
// 'useable' in this context means visible and picker, and not disabled.
func (p *Processor) GetCustomEmojis(ctx context.Context) ([]*apimodel.Emoji, gtserror.WithCode) {
- emojis, err := p.db.GetUseableEmojis(ctx)
+ emojis, err := p.state.DB.GetUseableEmojis(ctx)
if err != nil {
if err != db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error retrieving custom emojis: %s", err))
diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go
index 2a4ef2097..f9c6c23c2 100644
--- a/internal/processing/media/getfile.go
+++ b/internal/processing/media/getfile.go
@@ -54,7 +54,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc
owningAccountID := form.AccountID
// get the account that owns the media and make sure it's not suspended
- owningAccount, err := p.db.GetAccountByID(ctx, owningAccountID)
+ owningAccount, err := p.state.DB.GetAccountByID(ctx, owningAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", owningAccountID, err))
}
@@ -64,7 +64,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc
// make sure the requesting account and the media account don't block each other
if requestingAccount != nil {
- blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", owningAccountID, requestingAccount.ID, err))
}
@@ -117,7 +117,7 @@ func parseSize(s string) (media.Size, error) {
func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount *gtsmodel.Account, wantedMediaID string, owningAccountID string, mediaSize media.Size) (*apimodel.Content, gtserror.WithCode) {
// retrieve attachment from the database and do basic checks on it
- a, err := p.db.GetAttachmentByID(ctx, wantedMediaID)
+ a, err := p.state.DB.GetAttachmentByID(ctx, wantedMediaID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err))
}
@@ -209,7 +209,7 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning
// so this is more reliable than using full size url
imageStaticURL := uris.GenerateURIForAttachment(owningAccountID, string(media.TypeEmoji), string(media.SizeStatic), fileName, "png")
- e, err := p.db.GetEmojiByStaticURL(ctx, imageStaticURL)
+ e, err := p.state.DB.GetEmojiByStaticURL(ctx, imageStaticURL)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("emoji %s could not be taken from the db: %s", fileName, err))
}
@@ -237,12 +237,12 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning
func (p *Processor) retrieveFromStorage(ctx context.Context, storagePath string, content *apimodel.Content) (*apimodel.Content, gtserror.WithCode) {
// If running on S3 storage with proxying disabled then
// just fetch a pre-signed URL instead of serving the content.
- if url := p.storage.URL(ctx, storagePath); url != nil {
+ if url := p.state.Storage.URL(ctx, storagePath); url != nil {
content.URL = url
return content, nil
}
- reader, err := p.storage.GetStream(ctx, storagePath)
+ reader, err := p.state.Storage.GetStream(ctx, storagePath)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error retrieving from storage: %s", err))
}
diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go
index 03d5ba770..dad6ac538 100644
--- a/internal/processing/media/getmedia.go
+++ b/internal/processing/media/getmedia.go
@@ -30,7 +30,7 @@ import (
)
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
- attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil {
if err == db.ErrNoEntries {
// attachment doesn't exist
diff --git a/internal/processing/media/media.go b/internal/processing/media/media.go
index ca95e276f..51585102a 100644
--- a/internal/processing/media/media.go
+++ b/internal/processing/media/media.go
@@ -19,28 +19,25 @@
package media
import (
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
)
type Processor struct {
+ state *state.State
tc typeutils.TypeConverter
mediaManager media.Manager
transportController transport.Controller
- storage *storage.Driver
- db db.DB
}
// New returns a new media processor.
-func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver) Processor {
+func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor {
return Processor{
+ state: state,
tc: tc,
mediaManager: mediaManager,
transportController: transportController,
- storage: storage,
- db: db,
}
}
diff --git a/internal/processing/media/media_test.go b/internal/processing/media/media_test.go
index 1d223a66c..e706dbd7a 100644
--- a/internal/processing/media/media_test.go
+++ b/internal/processing/media/media_test.go
@@ -20,12 +20,11 @@ package media_test
import (
"github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
@@ -38,6 +37,7 @@ type MediaStandardTestSuite struct {
db db.DB
tc typeutils.TypeConverter
storage *storage.Driver
+ state state.State
mediaManager media.Manager
transportController transport.Controller
@@ -67,15 +67,19 @@ func (suite *MediaStandardTestSuite) SetupSuite() {
}
func (suite *MediaStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1))
- suite.mediaProcessor = mediaprocessing.New(suite.db, suite.tc, suite.mediaManager, suite.transportController, suite.storage)
+ suite.state.Storage = suite.storage
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
+ suite.mediaProcessor = mediaprocessing.New(&suite.state, suite.tc, suite.mediaManager, suite.transportController)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
}
diff --git a/internal/processing/media/unattach.go b/internal/processing/media/unattach.go
index 816b5134e..7c6f7dbac 100644
--- a/internal/processing/media/unattach.go
+++ b/internal/processing/media/unattach.go
@@ -33,7 +33,7 @@ import (
// Unattach unattaches the media attachment with the given ID from any statuses it was attached to, making it available
// for reattachment again.
func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
- attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
@@ -49,7 +49,7 @@ func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, med
attachment.UpdatedAt = time.Now()
attachment.StatusID = ""
- if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
+ if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err))
}
diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go
index c03df705b..cf49168f0 100644
--- a/internal/processing/media/update.go
+++ b/internal/processing/media/update.go
@@ -32,7 +32,7 @@ import (
// Update updates a media attachment with the given id, using the provided form parameters.
func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
- attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil {
if err == db.ErrNoEntries {
// attachment doesn't exist
@@ -62,7 +62,7 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, media
updatingColumns = append(updatingColumns, "focus_x", "focus_y")
}
- if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
+ if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err))
}
diff --git a/internal/processing/notification.go b/internal/processing/notification.go
index 05d0e82ee..57100e743 100644
--- a/internal/processing/notification.go
+++ b/internal/processing/notification.go
@@ -29,7 +29,7 @@ import (
)
func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, excludeTypes []string, limit int, maxID string, sinceID string) (*apimodel.PageableResponse, gtserror.WithCode) {
- notifs, err := p.db.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
+ notifs, err := p.state.DB.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -72,7 +72,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ex
}
func (p *Processor) NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode {
- err := p.db.ClearNotifications(ctx, authed.Account.ID)
+ err := p.state.DB.ClearNotifications(ctx, authed.Account.ID)
if err != nil {
return gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/processor.go b/internal/processing/processor.go
index 07fcdb8b3..bb75aab76 100644
--- a/internal/processing/processor.go
+++ b/internal/processing/processor.go
@@ -19,10 +19,11 @@
package processing
import (
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
+ "context"
+
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
mm "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
@@ -34,23 +35,19 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
"github.com/superseriousbusiness/gotosocial/internal/processing/stream"
"github.com/superseriousbusiness/gotosocial/internal/processing/user"
- "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
type Processor struct {
- clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
- fedWorker *concurrency.WorkerPool[messages.FromFederator]
-
federator federation.Federator
tc typeutils.TypeConverter
oauthServer oauth.Server
mediaManager mm.Manager
- storage *storage.Driver
statusTimelines timeline.Manager
- db db.DB
+ state *state.State
filter visibility.Filter
/*
@@ -105,76 +102,65 @@ func NewProcessor(
federator federation.Federator,
oauthServer oauth.Server,
mediaManager mm.Manager,
- storage *storage.Driver,
- db db.DB,
+ state *state.State,
emailSender email.Sender,
- clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
- fedWorker *concurrency.WorkerPool[messages.FromFederator],
) *Processor {
- parseMentionFunc := GetParseMentionFunc(db, federator)
+ parseMentionFunc := GetParseMentionFunc(state.DB, federator)
- filter := visibility.NewFilter(db)
+ filter := visibility.NewFilter(state.DB)
- return &Processor{
- clientWorker: clientWorker,
- fedWorker: fedWorker,
-
- federator: federator,
- tc: tc,
- oauthServer: oauthServer,
- mediaManager: mediaManager,
- storage: storage,
- statusTimelines: timeline.NewManager(StatusGrabFunction(db), StatusFilterFunction(db, filter), StatusPrepareFunction(db, tc), StatusSkipInsertFunction()),
- db: db,
- filter: filter,
-
- // sub processors
- account: account.New(db, tc, mediaManager, oauthServer, clientWorker, federator, parseMentionFunc),
- admin: admin.New(db, tc, mediaManager, federator.TransportController(), storage, clientWorker),
- fedi: fedi.New(db, tc, federator),
- media: media.New(db, tc, mediaManager, federator.TransportController(), storage),
- report: report.New(db, tc, clientWorker),
- status: status.New(db, tc, clientWorker, parseMentionFunc),
- stream: stream.New(db, oauthServer),
- user: user.New(db, emailSender),
+ processor := &Processor{
+ federator: federator,
+ tc: tc,
+ oauthServer: oauthServer,
+ mediaManager: mediaManager,
+ statusTimelines: timeline.NewManager(
+ StatusGrabFunction(state.DB),
+ StatusFilterFunction(state.DB, filter),
+ StatusPrepareFunction(state.DB, tc),
+ StatusSkipInsertFunction(),
+ ),
+ state: state,
+ filter: filter,
}
+
+ // sub processors
+ processor.account = account.New(state, tc, mediaManager, oauthServer, federator, parseMentionFunc)
+ processor.admin = admin.New(state, tc, mediaManager, federator.TransportController())
+ processor.fedi = fedi.New(state, tc, federator)
+ processor.media = media.New(state, tc, mediaManager, federator.TransportController())
+ processor.report = report.New(state, tc)
+ processor.status = status.New(state, tc, parseMentionFunc)
+ processor.stream = stream.New(state, oauthServer)
+ processor.user = user.New(state, emailSender)
+
+ return processor
}
-// Start starts the Processor, reading from its channels and passing messages back and forth.
+func (p *Processor) EnqueueClientAPI(ctx context.Context, msg messages.FromClientAPI) {
+ log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing client API")
+ _ = p.state.Workers.ClientAPI.MustEnqueueCtx(ctx, func(ctx context.Context) {
+ if err := p.ProcessFromClientAPI(ctx, msg); err != nil {
+ log.Errorf(ctx, "error processing client API message: %v", err)
+ }
+ })
+}
+
+func (p *Processor) EnqueueFederator(ctx context.Context, msg messages.FromFederator) {
+ log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing federator")
+ _ = p.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) {
+ if err := p.ProcessFromFederator(ctx, msg); err != nil {
+ log.Errorf(ctx, "error processing federator message: %v", err)
+ }
+ })
+}
+
+// Start starts the Processor.
func (p *Processor) Start() error {
- // Setup and start the client API worker pool
- p.clientWorker.SetProcessor(p.ProcessFromClientAPI)
- if err := p.clientWorker.Start(); err != nil {
- return err
- }
-
- // Setup and start the federator worker pool
- p.fedWorker.SetProcessor(p.ProcessFromFederator)
- if err := p.fedWorker.Start(); err != nil {
- return err
- }
-
- // Start status timelines
- if err := p.statusTimelines.Start(); err != nil {
- return err
- }
-
- return nil
+ return p.statusTimelines.Start()
}
-// Stop stops the processor cleanly, finishing handling any remaining messages before closing down.
+// Stop stops the processor cleanly.
func (p *Processor) Stop() error {
- if err := p.clientWorker.Stop(); err != nil {
- return err
- }
-
- if err := p.fedWorker.Stop(); err != nil {
- return err
- }
-
- if err := p.statusTimelines.Stop(); err != nil {
- return err
- }
-
- return nil
+ return p.statusTimelines.Stop()
}
diff --git a/internal/processing/processor_test.go b/internal/processing/processor_test.go
index 44857cb47..d8da87bcc 100644
--- a/internal/processing/processor_test.go
+++ b/internal/processing/processor_test.go
@@ -20,15 +20,14 @@ package processing_test
import (
"github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
@@ -40,6 +39,7 @@ type ProcessingStandardTestSuite struct {
suite.Suite
db db.DB
storage *storage.Driver
+ state state.State
mediaManager media.Manager
typeconverter typeutils.TypeConverter
httpClient *testrig.MockHTTPClient
@@ -86,25 +86,29 @@ func (suite *ProcessingStandardTestSuite) SetupSuite() {
}
func (suite *ProcessingStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
suite.typeconverter = testrig.NewTestTypeConverter(suite.db)
suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media")
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- suite.transportController = testrig.NewTestTransportController(suite.httpClient, suite.db, fedWorker)
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker)
+ suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient)
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.emailSender = testrig.NewEmailSender("../../web/template/", nil)
- suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
+ suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, &suite.state, suite.emailSender)
+ suite.state.Workers.EnqueueClientAPI = suite.processor.EnqueueClientAPI
+ suite.state.Workers.EnqueueFederator = suite.processor.EnqueueFederator
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
@@ -119,4 +123,5 @@ func (suite *ProcessingStandardTestSuite) TearDownTest() {
if err := suite.processor.Stop(); err != nil {
panic(err)
}
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/processing/report/create.go b/internal/processing/report/create.go
index 726d11666..e0918554e 100644
--- a/internal/processing/report/create.go
+++ b/internal/processing/report/create.go
@@ -41,7 +41,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
}
// validate + fetch target account
- targetAccount, err := p.db.GetAccountByID(ctx, form.AccountID)
+ targetAccount, err := p.state.DB.GetAccountByID(ctx, form.AccountID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("account with ID %s does not exist", form.AccountID)
@@ -52,7 +52,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
}
// fetch statuses by IDs given in the report form (noop if no statuses given)
- statuses, err := p.db.GetStatuses(ctx, form.StatusIDs)
+ statuses, err := p.state.DB.GetStatuses(ctx, form.StatusIDs)
if err != nil {
err = fmt.Errorf("db error fetching report target statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
@@ -79,11 +79,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
Forwarded: &form.Forward,
}
- if err := p.db.PutReport(ctx, report); err != nil {
+ if err := p.state.DB.PutReport(ctx, report); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityFlag,
GTSModel: report,
diff --git a/internal/processing/report/get.go b/internal/processing/report/get.go
index af2079b8a..0348c397c 100644
--- a/internal/processing/report/get.go
+++ b/internal/processing/report/get.go
@@ -32,7 +32,7 @@ import (
// Get returns the user view of a moderation report, with the given id.
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.Report, gtserror.WithCode) {
- report, err := p.db.GetReportByID(ctx, id)
+ report, err := p.state.DB.GetReportByID(ctx, id)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err)
@@ -64,7 +64,7 @@ func (p *Processor) GetMultiple(
minID string,
limit int,
) (*apimodel.PageableResponse, gtserror.WithCode) {
- reports, err := p.db.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit)
+ reports, err := p.state.DB.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit)
if err != nil {
if err == db.ErrNoEntries {
return util.EmptyPageableResponse(), nil
diff --git a/internal/processing/report/report.go b/internal/processing/report/report.go
index b5f4b301e..bc634af2e 100644
--- a/internal/processing/report/report.go
+++ b/internal/processing/report/report.go
@@ -19,22 +19,18 @@
package report
import (
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
)
type Processor struct {
- db db.DB
- tc typeutils.TypeConverter
- clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
+ state *state.State
+ tc typeutils.TypeConverter
}
-func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
+func New(state *state.State, tc typeutils.TypeConverter) Processor {
return Processor{
- tc: tc,
- db: db,
- clientWorker: clientWorker,
+ state: state,
+ tc: tc,
}
}
diff --git a/internal/processing/search.go b/internal/processing/search.go
index 05a1fe353..c5592fffd 100644
--- a/internal/processing/search.go
+++ b/internal/processing/search.go
@@ -88,7 +88,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
if username, domain, err := util.ExtractNamestringParts(maybeNamestring); err == nil {
l.Trace("search term is a mention, looking it up...")
- blocked, err := p.db.IsDomainBlocked(ctx, domain)
+ blocked, err := p.state.DB.IsDomainBlocked(ctx, domain)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err))
}
@@ -120,7 +120,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
if uri, err := url.Parse(query); err == nil {
if uri.Scheme == "https" || uri.Scheme == "http" {
l.Trace("search term is a uri, looking it up...")
- blocked, err := p.db.IsURIBlocked(ctx, uri)
+ blocked, err := p.state.DB.IsURIBlocked(ctx, uri)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err))
}
@@ -178,7 +178,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
*/
for _, foundAccount := range foundAccounts {
// make sure there's no block in either direction between the account and the requester
- blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true)
if err != nil {
err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err)
return nil, gtserror.NewErrorInternalError(err)
@@ -246,14 +246,14 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth,
)
// Search the database for existing account with ID URI.
- account, err = p.db.GetAccountByURI(ctx, uriStr)
+ account, err = p.state.DB.GetAccountByURI(ctx, uriStr)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err)
}
if account == nil {
// Else, search the database for existing by ID URL.
- account, err = p.db.GetAccountByURL(ctx, uriStr)
+ account, err = p.state.DB.GetAccountByURL(ctx, uriStr)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err)
@@ -281,7 +281,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o
}
// Search the database for existing account with USERNAME@DOMAIN
- account, err := p.db.GetAccountByUsernameDomain(ctx, username, domain)
+ account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, domain)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("searchAccountByUsernameDomain: error checking database for account %s@%s: %w", username, domain, err)
diff --git a/internal/processing/status/bookmark.go b/internal/processing/status/bookmark.go
index dde31ea7d..cf3787da2 100644
--- a/internal/processing/status/bookmark.go
+++ b/internal/processing/status/bookmark.go
@@ -32,7 +32,7 @@ import (
// BookmarkCreate adds a bookmark for the requestingAccount, targeting the given status (no-op if bookmark already exists).
func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -50,7 +50,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
// first check if the status is already bookmarked, if so we don't need to do anything
newBookmark := true
gtsBookmark := >smodel.StatusBookmark{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
+ if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
// we already have a bookmark for this status
newBookmark = false
}
@@ -67,7 +67,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
Status: targetStatus,
}
- if err := p.db.Put(ctx, gtsBookmark); err != nil {
+ if err := p.state.DB.Put(ctx, gtsBookmark); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting bookmark in database: %s", err))
}
}
@@ -83,7 +83,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
// BookmarkRemove removes a bookmark for the requesting account, targeting the given status (no-op if bookmark doesn't exist).
func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -101,13 +101,13 @@ func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmo
// first check if the status is actually bookmarked
toUnbookmark := false
gtsBookmark := >smodel.StatusBookmark{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
+ if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
// we have a bookmark for this status
toUnbookmark = true
}
if toUnbookmark {
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err))
}
}
diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go
index 4dfe17019..6756d816c 100644
--- a/internal/processing/status/boost.go
+++ b/internal/processing/status/boost.go
@@ -33,7 +33,7 @@ import (
// BoostCreate processes the boost/reblog of a given status, returning the newly-created boost if all is well.
func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -47,7 +47,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
// boost boosts, and it looks absolutely bizarre in the UI
if targetStatus.BoostOfID != "" {
if targetStatus.BoostOf == nil {
- b, err := p.db.GetStatusByID(ctx, targetStatus.BoostOfID)
+ b, err := p.state.DB.GetStatusByID(ctx, targetStatus.BoostOfID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("couldn't fetch boosted status %s", targetStatus.BoostOfID))
}
@@ -74,12 +74,12 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
boostWrapperStatus.BoostOfAccount = targetStatus.Account
// put the boost in the database
- if err := p.db.PutStatus(ctx, boostWrapperStatus); err != nil {
+ if err := p.state.DB.PutStatus(ctx, boostWrapperStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// send it back to the processor for async processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityCreate,
GTSModel: boostWrapperStatus,
@@ -98,7 +98,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
// BoostRemove processes the unboost/unreblog of a given status, returning the status if all is well.
func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -128,7 +128,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
Value: requestingAccount.ID,
},
}
- err = p.db.GetWhere(ctx, where, gtsBoost)
+ err = p.state.DB.GetWhere(ctx, where, gtsBoost)
if err == nil {
// we have a boost
toUnboost = true
@@ -151,7 +151,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
gtsBoost.BoostOf.Account = targetStatus.Account
// send it back to the processor for async processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityUndo,
GTSModel: gtsBoost,
@@ -170,7 +170,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
// StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings.
func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", targetStatusID, err)
if !errors.Is(err, db.ErrNoEntries) {
@@ -181,7 +181,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
if boostOfID := targetStatus.BoostOfID; boostOfID != "" {
// the target status is a boost wrapper, redirect this request to the status it boosts
- boostedStatus, err := p.db.GetStatusByID(ctx, boostOfID)
+ boostedStatus, err := p.state.DB.GetStatusByID(ctx, boostOfID)
if err != nil {
wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", boostOfID, err)
if !errors.Is(err, db.ErrNoEntries) {
@@ -202,7 +202,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
return nil, gtserror.NewErrorNotFound(err)
}
- statusReblogs, err := p.db.GetStatusReblogs(ctx, targetStatus)
+ statusReblogs, err := p.state.DB.GetStatusReblogs(ctx, targetStatus)
if err != nil {
err = fmt.Errorf("BoostedBy: error seeing who boosted status: %s", err)
return nil, gtserror.NewErrorNotFound(err)
@@ -211,7 +211,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
// filter account IDs so the user doesn't see accounts they blocked or which blocked them
accountIDs := make([]string, 0, len(statusReblogs))
for _, s := range statusReblogs {
- blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)
if err != nil {
err = fmt.Errorf("BoostedBy: error checking blocks: %s", err)
return nil, gtserror.NewErrorNotFound(err)
@@ -226,7 +226,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
// fetch accounts + create their API representations
apiAccounts := make([]*apimodel.Account, 0, len(accountIDs))
for _, accountID := range accountIDs {
- account, err := p.db.GetAccountByID(ctx, accountID)
+ account, err := p.state.DB.GetAccountByID(ctx, accountID)
if err != nil {
wrapped := fmt.Errorf("BoostedBy: error fetching account %s: %s", accountID, err)
if !errors.Is(err, db.ErrNoEntries) {
diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go
index f47c850dd..4e5399469 100644
--- a/internal/processing/status/create.go
+++ b/internal/processing/status/create.go
@@ -61,11 +61,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli
Text: form.Status,
}
- if errWithCode := processReplyToID(ctx, p.db, form, account.ID, newStatus); errWithCode != nil {
+ if errWithCode := processReplyToID(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil {
return nil, errWithCode
}
- if errWithCode := processMediaIDs(ctx, p.db, form, account.ID, newStatus); errWithCode != nil {
+ if errWithCode := processMediaIDs(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil {
return nil, errWithCode
}
@@ -77,17 +77,17 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli
return nil, gtserror.NewErrorInternalError(err)
}
- if err := processContent(ctx, p.db, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil {
+ if err := processContent(ctx, p.state.DB, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// put the new status in the database
- if err := p.db.PutStatus(ctx, newStatus); err != nil {
+ if err := p.state.DB.PutStatus(ctx, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// send it back to the processor for async processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
GTSModel: newStatus,
diff --git a/internal/processing/status/delete.go b/internal/processing/status/delete.go
index d3a03aad6..0e9510e08 100644
--- a/internal/processing/status/delete.go
+++ b/internal/processing/status/delete.go
@@ -32,7 +32,7 @@ import (
// Delete processes the delete of a given status, returning the deleted status if the delete goes through.
func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -50,7 +50,7 @@ func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Acco
}
// send the status back to the processor for async processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityDelete,
GTSModel: targetStatus,
diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go
index 3bcb1835f..3025c720d 100644
--- a/internal/processing/status/fave.go
+++ b/internal/processing/status/fave.go
@@ -35,7 +35,7 @@ import (
// FaveCreate processes the faving of a given status, returning the updated status if the fave goes through.
func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -57,7 +57,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.
// first check if the status is already faved, if so we don't need to do anything
newFave := true
gtsFave := >smodel.StatusFave{}
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil {
+ if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil {
// we already have a fave for this status
newFave = false
}
@@ -77,12 +77,12 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.
URI: uris.GenerateURIForLike(requestingAccount.Username, thisFaveID),
}
- if err := p.db.Put(ctx, gtsFave); err != nil {
+ if err := p.state.DB.Put(ctx, gtsFave); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting fave in database: %s", err))
}
// send it back to the processor for async processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityLike,
APActivityType: ap.ActivityCreate,
GTSModel: gtsFave,
@@ -102,7 +102,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.
// FaveRemove processes the unfaving of a given status, returning the updated status if the fave goes through.
func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -122,7 +122,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.
var toUnfave bool
gtsFave := >smodel.StatusFave{}
- err = p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave)
+ err = p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave)
if err == nil {
// we have a fave
toUnfave = true
@@ -138,12 +138,12 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.
if toUnfave {
// we had a fave, so take some action to get rid of it
- if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil {
+ if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err))
}
// send it back to the processor for async processing
- p.clientWorker.Queue(messages.FromClientAPI{
+ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityLike,
APActivityType: ap.ActivityUndo,
GTSModel: gtsFave,
@@ -162,7 +162,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.
// FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings.
func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -178,7 +178,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc
return nil, gtserror.NewErrorNotFound(errors.New("status is not visible"))
}
- statusFaves, err := p.db.GetStatusFaves(ctx, targetStatus)
+ statusFaves, err := p.state.DB.GetStatusFaves(ctx, targetStatus)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing who faved status: %s", err))
}
@@ -186,7 +186,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc
// filter the list so the user doesn't see accounts they blocked or which blocked them
filteredAccounts := []*gtsmodel.Account{}
for _, fave := range statusFaves {
- blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true)
+ blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking blocks: %s", err))
}
diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go
index edefeb440..51c384c44 100644
--- a/internal/processing/status/get.go
+++ b/internal/processing/status/get.go
@@ -31,7 +31,7 @@ import (
// Get gets the given status, taking account of privacy settings and blocks etc.
func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -57,7 +57,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account
// ContextGet returns the context (previous and following posts) from the given status ID.
func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -78,7 +78,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.
Descendants: []apimodel.Status{},
}
- parents, err := p.db.GetStatusParents(ctx, targetStatus, false)
+ parents, err := p.state.DB.GetStatusParents(ctx, targetStatus, false)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -96,7 +96,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.
return context.Ancestors[i].ID < context.Ancestors[j].ID
})
- children, err := p.db.GetStatusChildren(ctx, targetStatus, false, "")
+ children, err := p.state.DB.GetStatusChildren(ctx, targetStatus, false, "")
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/status/pin.go b/internal/processing/status/pin.go
index 3e50b0c73..6001a147f 100644
--- a/internal/processing/status/pin.go
+++ b/internal/processing/status/pin.go
@@ -39,7 +39,7 @@ const allowedPinnedCount = 10
// - Status is public, unlisted, or followers-only.
// - Status is not a boost.
func (p *Processor) getPinnableStatus(ctx context.Context, targetStatusID string, requestingAccountID string) (*gtsmodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
+ targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil {
err = fmt.Errorf("error fetching status %s: %w", targetStatusID, err)
return nil, gtserror.NewErrorNotFound(err)
@@ -84,7 +84,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A
return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error())
}
- pinnedCount, err := p.db.CountAccountPinned(ctx, requestingAccount.ID)
+ pinnedCount, err := p.state.DB.CountAccountPinned(ctx, requestingAccount.ID)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking number of pinned statuses: %w", err))
}
@@ -95,7 +95,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A
}
targetStatus.PinnedAt = time.Now()
- if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
+ if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error pinning status: %w", err))
}
@@ -126,7 +126,7 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A
if targetStatus.PinnedAt.IsZero() {
targetStatus.PinnedAt = time.Time{}
- if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
+ if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error unpinning status: %w", err))
}
}
diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go
index c91fd85d1..909b06481 100644
--- a/internal/processing/status/status.go
+++ b/internal/processing/status/status.go
@@ -19,32 +19,28 @@
package status
import (
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
type Processor struct {
+ state *state.State
tc typeutils.TypeConverter
- db db.DB
filter visibility.Filter
formatter text.Formatter
- clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
parseMention gtsmodel.ParseMentionFunc
}
// New returns a new status processor.
-func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
+func New(state *state.State, tc typeutils.TypeConverter, parseMention gtsmodel.ParseMentionFunc) Processor {
return Processor{
+ state: state,
tc: tc,
- db: db,
- filter: visibility.NewFilter(db),
- formatter: text.NewFormatter(db),
- clientWorker: clientWorker,
+ filter: visibility.NewFilter(state.DB),
+ formatter: text.NewFormatter(state.DB),
parseMention: parseMention,
}
}
diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go
index 272d2c8ea..1b35b69db 100644
--- a/internal/processing/status/status_test.go
+++ b/internal/processing/status/status_test.go
@@ -19,17 +19,14 @@
package status_test
import (
- "context"
-
"github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
@@ -42,9 +39,9 @@ type StatusStandardTestSuite struct {
typeConverter typeutils.TypeConverter
tc transport.Controller
storage *storage.Driver
+ state state.State
mediaManager media.Manager
federator federation.Federator
- clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -74,21 +71,22 @@ func (suite *StatusStandardTestSuite) SetupSuite() {
}
func (suite *StatusStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ testrig.StartWorkers(&suite.state)
+
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
-
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
suite.typeConverter = testrig.NewTestTypeConverter(suite.db)
- suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
- suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker)
+ suite.state.DB = suite.db
+
+ suite.tc = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
suite.storage = testrig.NewInMemoryStorage()
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.federator = testrig.NewTestFederator(suite.db, suite.tc, suite.storage, suite.mediaManager, fedWorker)
- suite.status = status.New(suite.db, suite.typeConverter, suite.clientWorker, processing.GetParseMentionFunc(suite.db, suite.federator))
- suite.clientWorker.SetProcessor(func(ctx context.Context, msg messages.FromClientAPI) error { return nil })
- suite.NoError(suite.clientWorker.Start())
+ suite.state.Storage = suite.storage
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager)
+ suite.status = status.New(&suite.state, suite.typeConverter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
@@ -97,4 +95,5 @@ func (suite *StatusStandardTestSuite) SetupTest() {
func (suite *StatusStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
}
diff --git a/internal/processing/statustimeline.go b/internal/processing/statustimeline.go
index 7c9f36f16..8c8e20316 100644
--- a/internal/processing/statustimeline.go
+++ b/internal/processing/statustimeline.go
@@ -173,7 +173,7 @@ func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, max
}
func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
- statuses, err := p.db.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)
+ statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)
if err != nil {
if err == db.ErrNoEntries {
// there are just no entries left
@@ -218,7 +218,7 @@ func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, m
}
func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) {
- statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
+ statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
if err != nil {
if err == db.ErrNoEntries {
// there are just no entries left
@@ -255,7 +255,7 @@ func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth
apiStatuses := []*apimodel.Status{}
for _, s := range statuses {
targetAccount := >smodel.Account{}
- if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil {
+ if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {
if err == db.ErrNoEntries {
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
@@ -288,7 +288,7 @@ func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth,
apiStatuses := []*apimodel.Status{}
for _, s := range statuses {
targetAccount := >smodel.Account{}
- if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil {
+ if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {
if err == db.ErrNoEntries {
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
diff --git a/internal/processing/stream/authorize.go b/internal/processing/stream/authorize.go
index 5f6811db9..a30e6fb33 100644
--- a/internal/processing/stream/authorize.go
+++ b/internal/processing/stream/authorize.go
@@ -41,7 +41,7 @@ func (p *Processor) Authorize(ctx context.Context, accessToken string) (*gtsmode
return nil, gtserror.NewErrorUnauthorized(err)
}
- user, err := p.db.GetUserByID(ctx, uid)
+ user, err := p.state.DB.GetUserByID(ctx, uid)
if err != nil {
if err == db.ErrNoEntries {
err := fmt.Errorf("no user found for validated uid %s", uid)
@@ -50,7 +50,7 @@ func (p *Processor) Authorize(ctx context.Context, accessToken string) (*gtsmode
return nil, gtserror.NewErrorInternalError(err)
}
- acct, err := p.db.GetAccountByID(ctx, user.AccountID)
+ acct, err := p.state.DB.GetAccountByID(ctx, user.AccountID)
if err != nil {
if err == db.ErrNoEntries {
err := fmt.Errorf("no account found for validated uid %s", uid)
diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go
index 3c38e720a..a10ab2474 100644
--- a/internal/processing/stream/stream.go
+++ b/internal/processing/stream/stream.go
@@ -22,22 +22,21 @@ import (
"errors"
"sync"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
type Processor struct {
- db db.DB
+ state *state.State
oauthServer oauth.Server
- streamMap *sync.Map
+ streamMap sync.Map
}
-func New(db db.DB, oauthServer oauth.Server) Processor {
+func New(state *state.State, oauthServer oauth.Server) Processor {
return Processor{
- db: db,
+ state: state,
oauthServer: oauthServer,
- streamMap: &sync.Map{},
}
}
diff --git a/internal/processing/stream/stream_test.go b/internal/processing/stream/stream_test.go
index 907c7e1d0..9e1eb57f2 100644
--- a/internal/processing/stream/stream_test.go
+++ b/internal/processing/stream/stream_test.go
@@ -24,6 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing/stream"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -33,19 +34,23 @@ type StreamTestSuite struct {
testTokens map[string]*gtsmodel.Token
db db.DB
oauthServer oauth.Server
+ state state.State
streamProcessor stream.Processor
}
func (suite *StreamTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+
testrig.InitTestLog()
testrig.InitTestConfig()
suite.testAccounts = testrig.NewTestAccounts()
suite.testTokens = testrig.NewTestTokens()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
- suite.streamProcessor = stream.New(suite.db, suite.oauthServer)
+ suite.streamProcessor = stream.New(&suite.state, suite.oauthServer)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
diff --git a/internal/processing/user/email.go b/internal/processing/user/email.go
index 349e27f47..c55488954 100644
--- a/internal/processing/user/email.go
+++ b/internal/processing/user/email.go
@@ -56,7 +56,7 @@ func (p *Processor) EmailSendConfirmation(ctx context.Context, user *gtsmodel.Us
// pull our instance entry from the database so we can greet the user nicely in the email
instance := >smodel.Instance{}
host := config.GetHost()
- if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, instance); err != nil {
+ if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, instance); err != nil {
return fmt.Errorf("SendConfirmEmail: error getting instance: %s", err)
}
@@ -78,7 +78,7 @@ func (p *Processor) EmailSendConfirmation(ctx context.Context, user *gtsmodel.Us
user.LastEmailedAt = time.Now()
user.UpdatedAt = time.Now()
- if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
+ if err := p.state.DB.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err)
}
@@ -92,7 +92,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U
return nil, gtserror.NewErrorNotFound(errors.New("no token provided"))
}
- user, err := p.db.GetUserByConfirmationToken(ctx, token)
+ user, err := p.state.DB.GetUserByConfirmationToken(ctx, token)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err)
@@ -101,7 +101,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U
}
if user.Account == nil {
- a, err := p.db.GetAccountByID(ctx, user.AccountID)
+ a, err := p.state.DB.GetAccountByID(ctx, user.AccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
@@ -129,7 +129,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U
user.ConfirmationToken = ""
user.UpdatedAt = time.Now()
- if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
+ if err := p.state.DB.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/user/password.go b/internal/processing/user/password.go
index 3475e005e..72ef5ffa7 100644
--- a/internal/processing/user/password.go
+++ b/internal/processing/user/password.go
@@ -44,7 +44,7 @@ func (p *Processor) PasswordChange(ctx context.Context, user *gtsmodel.User, old
user.EncryptedPassword = string(newPasswordHash)
- if err := p.db.UpdateUser(ctx, user, "encrypted_password"); err != nil {
+ if err := p.state.DB.UpdateUser(ctx, user, "encrypted_password"); err != nil {
return gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/user/user.go b/internal/processing/user/user.go
index fce628d0c..4fda4c1f6 100644
--- a/internal/processing/user/user.go
+++ b/internal/processing/user/user.go
@@ -19,19 +19,19 @@
package user
import (
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
)
type Processor struct {
+ state *state.State
emailSender email.Sender
- db db.DB
}
// New returns a new user processor
-func New(db db.DB, emailSender email.Sender) Processor {
+func New(state *state.State, emailSender email.Sender) Processor {
return Processor{
+ state: state,
emailSender: emailSender,
- db: db,
}
}
diff --git a/internal/processing/user/user_test.go b/internal/processing/user/user_test.go
index 83ab5892e..7379b568e 100644
--- a/internal/processing/user/user_test.go
+++ b/internal/processing/user/user_test.go
@@ -24,6 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing/user"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -31,6 +32,7 @@ type UserStandardTestSuite struct {
suite.Suite
emailSender email.Sender
db db.DB
+ state state.State
testUsers map[string]*gtsmodel.User
@@ -40,15 +42,19 @@ type UserStandardTestSuite struct {
}
func (suite *UserStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
+
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
suite.testUsers = testrig.NewTestUsers()
- suite.user = user.New(suite.db, suite.emailSender)
+ suite.user = user.New(&suite.state, suite.emailSender)
testrig.StandardDBSetup(suite.db, nil)
}
diff --git a/internal/text/formatter_test.go b/internal/text/formatter_test.go
index 32ae74488..304a538fc 100644
--- a/internal/text/formatter_test.go
+++ b/internal/text/formatter_test.go
@@ -20,12 +20,12 @@ package text_test
import (
"context"
+
"github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -66,13 +66,15 @@ func (suite *TextStandardTestSuite) SetupSuite() {
}
func (suite *TextStandardTestSuite) SetupTest() {
+ var state state.State
+ state.Caches.Init()
+
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&state)
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- federator := testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../testrig/media"), suite.db, fedWorker), nil, nil, fedWorker)
+ federator := testrig.NewTestFederator(&state, testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(nil, "../../testrig/media")), nil)
suite.parseMention = processing.GetParseMentionFunc(suite.db, federator)
suite.formatter = text.NewFormatter(suite.db)
diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go
index 9be1fdb90..0c866c7a8 100644
--- a/internal/timeline/get_test.go
+++ b/internal/timeline/get_test.go
@@ -27,6 +27,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -42,10 +43,13 @@ func (suite *GetTestSuite) SetupSuite() {
}
func (suite *GetTestSuite) SetupTest() {
+ var state state.State
+ state.Caches.Init()
+
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.filter = visibility.NewFilter(suite.db)
diff --git a/internal/timeline/index_test.go b/internal/timeline/index_test.go
index 692688aba..9d79f12c2 100644
--- a/internal/timeline/index_test.go
+++ b/internal/timeline/index_test.go
@@ -26,6 +26,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -41,10 +42,13 @@ func (suite *IndexTestSuite) SetupSuite() {
}
func (suite *IndexTestSuite) SetupTest() {
+ var state state.State
+ state.Caches.Init()
+
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.filter = visibility.NewFilter(suite.db)
diff --git a/internal/timeline/manager_test.go b/internal/timeline/manager_test.go
index 03804bf78..e033ffda4 100644
--- a/internal/timeline/manager_test.go
+++ b/internal/timeline/manager_test.go
@@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -39,10 +40,13 @@ func (suite *ManagerTestSuite) SetupSuite() {
}
func (suite *ManagerTestSuite) SetupTest() {
+ var state state.State
+ state.Caches.Init()
+
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.filter = visibility.NewFilter(suite.db)
diff --git a/internal/timeline/prune_test.go b/internal/timeline/prune_test.go
index 9d539e0e0..48bba41dc 100644
--- a/internal/timeline/prune_test.go
+++ b/internal/timeline/prune_test.go
@@ -26,6 +26,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -41,10 +42,13 @@ func (suite *PruneTestSuite) SetupSuite() {
}
func (suite *PruneTestSuite) SetupTest() {
+ var state state.State
+ state.Caches.Init()
+
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.filter = visibility.NewFilter(suite.db)
diff --git a/internal/trans/import_test.go b/internal/trans/import_test.go
index 128ac58a3..a53305c79 100644
--- a/internal/trans/import_test.go
+++ b/internal/trans/import_test.go
@@ -27,6 +27,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/trans"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -57,8 +58,11 @@ func (suite *ImportMinimalTestSuite) TestImportMinimalOK() {
suite.NotEmpty(b)
fmt.Println(string(b))
+ var state state.State
+ state.Caches.Init()
+
// create a new database with just the tables created, no entries
- newDB := testrig.NewTestDB()
+ newDB := testrig.NewTestDB(&state)
importer := trans.NewImporter(newDB)
err = importer.Import(ctx, tempFilePath)
diff --git a/internal/trans/trans_test.go b/internal/trans/trans_test.go
index 9364891a0..2b6bbb57b 100644
--- a/internal/trans/trans_test.go
+++ b/internal/trans/trans_test.go
@@ -22,6 +22,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -32,12 +33,15 @@ type TransTestSuite struct {
}
func (suite *TransTestSuite) SetupTest() {
+ var state state.State
+ state.Caches.Init()
+
testrig.InitTestConfig()
testrig.InitTestLog()
suite.testAccounts = testrig.NewTestAccounts()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&state)
testrig.StandardDBSetup(suite.db, nil)
}
diff --git a/internal/typeutils/converter_test.go b/internal/typeutils/converter_test.go
index c6f3c2579..bc81a7c6d 100644
--- a/internal/typeutils/converter_test.go
+++ b/internal/typeutils/converter_test.go
@@ -23,6 +23,7 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -481,10 +482,13 @@ type TypeUtilsTestSuite struct {
}
func (suite *TypeUtilsTestSuite) SetupSuite() {
+ var state state.State
+ state.Caches.Init()
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&state)
suite.testAccounts = testrig.NewTestAccounts()
suite.testStatuses = testrig.NewTestStatuses()
suite.testAttachments = testrig.NewTestAttachments()
diff --git a/internal/visibility/filter_test.go b/internal/visibility/filter_test.go
index bd7a8671e..9697dd72c 100644
--- a/internal/visibility/filter_test.go
+++ b/internal/visibility/filter_test.go
@@ -22,6 +22,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -60,10 +61,13 @@ func (suite *FilterStandardTestSuite) SetupSuite() {
}
func (suite *FilterStandardTestSuite) SetupTest() {
+ var state state.State
+ state.Caches.Init()
+
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB()
+ suite.db = testrig.NewTestDB(&state)
suite.filter = visibility.NewFilter(suite.db)
testrig.StandardDBSetup(suite.db, nil)
diff --git a/internal/workers/workers.go b/internal/workers/workers.go
index 77b3065ce..b29d115aa 100644
--- a/internal/workers/workers.go
+++ b/internal/workers/workers.go
@@ -19,20 +19,28 @@ along with this program. If not, see .
package workers
import (
+ "context"
"log"
"runtime"
"codeberg.org/gruf/go-runners"
"codeberg.org/gruf/go-sched"
+ "github.com/superseriousbusiness/gotosocial/internal/messages"
)
type Workers struct {
// Main task scheduler instance.
Scheduler sched.Scheduler
- // Processor / federator worker pools.
- // ClientAPI runners.WorkerPool
- // Federator runners.WorkerPool
+ // ClientAPI / federator worker pools.
+ ClientAPI runners.WorkerPool
+ Federator runners.WorkerPool
+
+ // Enqueue functions for clientAPI / federator worker pools,
+ // these are pointers to Processor{}.Enqueue___() msg functions.
+ // This prevents dependency cycling as Processor depends on Workers.
+ EnqueueClientAPI func(context.Context, messages.FromClientAPI)
+ EnqueueFederator func(context.Context, messages.FromFederator)
// Media manager worker pools.
Media runners.WorkerPool
@@ -50,13 +58,13 @@ func (w *Workers) Start() {
return w.Scheduler.Start(nil)
})
- // tryUntil("starting client API workerpool", 5, func() bool {
- // return w.ClientAPI.Start(4*maxprocs, 400*maxprocs)
- // })
+ tryUntil("starting client API workerpool", 5, func() bool {
+ return w.ClientAPI.Start(4*maxprocs, 400*maxprocs)
+ })
- // tryUntil("starting federator workerpool", 5, func() bool {
- // return w.Federator.Start(4*maxprocs, 400*maxprocs)
- // })
+ tryUntil("starting federator workerpool", 5, func() bool {
+ return w.Federator.Start(4*maxprocs, 400*maxprocs)
+ })
tryUntil("starting media workerpool", 5, func() bool {
return w.Media.Start(8*maxprocs, 80*maxprocs)
@@ -66,8 +74,8 @@ func (w *Workers) Start() {
// Stop will stop all of the contained worker pools (and global scheduler).
func (w *Workers) Stop() {
tryUntil("stopping scheduler", 5, w.Scheduler.Stop)
- // tryUntil("stopping client API workerpool", 5, w.ClientAPI.Stop)
- // tryUntil("stopping federator workerpool", 5, w.Federator.Stop)
+ tryUntil("stopping client API workerpool", 5, w.ClientAPI.Stop)
+ tryUntil("stopping federator workerpool", 5, w.Federator.Stop)
tryUntil("stopping media workerpool", 5, w.Media.Stop)
}
diff --git a/testrig/db.go b/testrig/db.go
index 8479347eb..1a29aa8b9 100644
--- a/testrig/db.go
+++ b/testrig/db.go
@@ -71,7 +71,7 @@ var testModels = []interface{}{
//
// If the environment variable GTS_DB_PORT is set, it will take that
// value as the port instead.
-func NewTestDB() db.DB {
+func NewTestDB(state *state.State) db.DB {
if alternateAddress := os.Getenv("GTS_DB_ADDRESS"); alternateAddress != "" {
config.SetDbAddress(alternateAddress)
}
@@ -88,10 +88,9 @@ func NewTestDB() db.DB {
config.SetDbPort(int(port))
}
- var state state.State
state.Caches.Init()
- testDB, err := bundb.NewBunDBService(context.Background(), &state)
+ testDB, err := bundb.NewBunDBService(context.Background(), state)
if err != nil {
log.Panic(nil, err)
}
diff --git a/testrig/federatingdb.go b/testrig/federatingdb.go
index 9b1f1961e..27adc4c51 100644
--- a/testrig/federatingdb.go
+++ b/testrig/federatingdb.go
@@ -19,13 +19,11 @@
package testrig
import (
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
)
// NewTestFederatingDB returns a federating DB with the underlying db
-func NewTestFederatingDB(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federatingdb.DB {
- return federatingdb.New(db, fedWorker, NewTestTypeConverter(db))
+func NewTestFederatingDB(state *state.State) federatingdb.DB {
+ return federatingdb.New(state, NewTestTypeConverter(state.DB))
}
diff --git a/testrig/federator.go b/testrig/federator.go
index 605a2c8f3..bc150633e 100644
--- a/testrig/federator.go
+++ b/testrig/federator.go
@@ -19,16 +19,13 @@
package testrig
import (
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/transport"
)
// NewTestFederator returns a federator with the given database and (mock!!) transport controller.
-func NewTestFederator(db db.DB, tc transport.Controller, storage *storage.Driver, mediaManager media.Manager, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federation.Federator {
- return federation.NewFederator(db, NewTestFederatingDB(db, fedWorker), tc, NewTestTypeConverter(db), mediaManager)
+func NewTestFederator(state *state.State, tc transport.Controller, mediaManager media.Manager) federation.Federator {
+ return federation.NewFederator(state.DB, NewTestFederatingDB(state), tc, NewTestTypeConverter(state.DB), mediaManager)
}
diff --git a/testrig/mediahandler.go b/testrig/mediahandler.go
index a1863218c..b4b992b0b 100644
--- a/testrig/mediahandler.go
+++ b/testrig/mediahandler.go
@@ -19,17 +19,12 @@
package testrig
import (
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/state"
- "github.com/superseriousbusiness/gotosocial/internal/storage"
)
// NewTestMediaManager returns a media handler with the default test config, and the given db and storage.
-func NewTestMediaManager(db db.DB, storage *storage.Driver) media.Manager {
- var state state.State
- state.DB = db
- state.Storage = storage
- state.Workers.Start()
- return media.NewManager(&state)
+func NewTestMediaManager(state *state.State) media.Manager {
+ StartWorkers(state) // ensure started
+ return media.NewManager(state)
}
diff --git a/testrig/processor.go b/testrig/processor.go
index f451d4ad0..856ee523d 100644
--- a/testrig/processor.go
+++ b/testrig/processor.go
@@ -19,17 +19,17 @@
package testrig
import (
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
)
// NewTestProcessor returns a Processor suitable for testing purposes
-func NewTestProcessor(db db.DB, storage *storage.Driver, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], fedWorker *concurrency.WorkerPool[messages.FromFederator]) *processing.Processor {
- return processing.NewProcessor(NewTestTypeConverter(db), federator, NewTestOauthServer(db), mediaManager, storage, db, emailSender, clientWorker, fedWorker)
+func NewTestProcessor(state *state.State, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager) *processing.Processor {
+ p := processing.NewProcessor(NewTestTypeConverter(state.DB), federator, NewTestOauthServer(state.DB), mediaManager, state, emailSender)
+ state.Workers.EnqueueClientAPI = p.EnqueueClientAPI
+ state.Workers.EnqueueFederator = p.EnqueueFederator
+ return p
}
diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go
index 7565a741c..9657205f6 100644
--- a/testrig/transportcontroller.go
+++ b/testrig/transportcontroller.go
@@ -30,12 +30,10 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/transport"
)
@@ -53,8 +51,8 @@ const (
// Unlike the other test interfaces provided in this package, you'll probably want to call this function
// PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular)
// basis.
-func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) transport.Controller {
- return transport.NewController(db, NewTestFederatingDB(db, fedWorker), &federation.Clock{}, client)
+func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller {
+ return transport.NewController(state.DB, NewTestFederatingDB(state), &federation.Clock{}, client)
}
type MockHTTPClient struct {
diff --git a/testrig/util.go b/testrig/util.go
index cc392b315..0cda93024 100644
--- a/testrig/util.go
+++ b/testrig/util.go
@@ -20,13 +20,34 @@ package testrig
import (
"bytes"
+ "context"
"io"
"mime/multipart"
"net/url"
"os"
"time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
)
+func StartWorkers(state *state.State) {
+ state.Workers.EnqueueClientAPI = func(context.Context, messages.FromClientAPI) {}
+ state.Workers.EnqueueFederator = func(context.Context, messages.FromFederator) {}
+
+ _ = state.Workers.Scheduler.Start(nil)
+ _ = state.Workers.ClientAPI.Start(1, 10)
+ _ = state.Workers.Federator.Start(1, 10)
+ _ = state.Workers.Media.Start(1, 10)
+}
+
+func StopWorkers(state *state.State) {
+ _ = state.Workers.Scheduler.Stop()
+ _ = state.Workers.ClientAPI.Stop()
+ _ = state.Workers.Federator.Stop()
+ _ = state.Workers.Media.Stop()
+}
+
// CreateMultipartFormData is a handy function for taking a fieldname and a filename, and creating a multipart form bytes buffer
// with the file contents set in the given fieldname. The extraFields param can be used to add extra FormFields to the request, as necessary.
// The returned bytes.Buffer b can be used like so: