diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go
index 29c05ba1a..4a8993598 100644
--- a/cmd/gotosocial/action/server/server.go
+++ b/cmd/gotosocial/action/server/server.go
@@ -21,7 +21,6 @@ package server
import (
"context"
"fmt"
- "net/http"
"os"
"os/signal"
"path"
@@ -56,12 +55,14 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
+ "github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
@@ -71,7 +72,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/web"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Start creates and starts a gotosocial server
@@ -93,8 +93,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
// NOTE: these MUST NOT be used until they are passed to the
// processor and it is started. The reason being that the processor
// sets the Worker process functions and start the underlying pools
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
federatingDB := federatingdb.New(dbService, fedWorker)
@@ -120,13 +120,16 @@ var Start action.GTSAction = func(ctx context.Context) error {
return fmt.Errorf("error creating storage backend: %s", err)
}
+ // Build HTTP client (TODO: add configurables here)
+ client := httpclient.New(httpclient.Config{})
+
// build backend handlers
mediaManager, err := media.NewManager(dbService, storage)
if err != nil {
return fmt.Errorf("error creating media manager: %s", err)
}
oauthServer := oauth.New(ctx, dbService)
- transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, http.DefaultClient)
+ transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client)
federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager)
// decide whether to create a noop email sender (won't send emails) or a real one
diff --git a/cmd/gotosocial/action/testrig/testrig.go b/cmd/gotosocial/action/testrig/testrig.go
index 010c730a0..cb587c51d 100644
--- a/cmd/gotosocial/action/testrig/testrig.go
+++ b/cmd/gotosocial/action/testrig/testrig.go
@@ -54,11 +54,11 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/web"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -74,8 +74,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
testrig.StandardStorageSetup(storageBackend, "./testrig/media")
// Create client API and federator worker pools
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// build backend handlers
oauthServer := testrig.NewTestOauthServer(dbService)
diff --git a/go.mod b/go.mod
index bd10b6d73..1646d47e6 100644
--- a/go.mod
+++ b/go.mod
@@ -3,7 +3,10 @@ module github.com/superseriousbusiness/gotosocial
go 1.18
require (
+ codeberg.org/gruf/go-byteutil v1.0.1
+ codeberg.org/gruf/go-cache/v2 v2.0.1
codeberg.org/gruf/go-debug v1.1.2
+ codeberg.org/gruf/go-errors/v2 v2.0.1
codeberg.org/gruf/go-mutexes v1.1.2
codeberg.org/gruf/go-runners v1.2.1
codeberg.org/gruf/go-store v1.3.7
@@ -52,8 +55,6 @@ require (
require (
codeberg.org/gruf/go-bitutil v1.0.0 // indirect
codeberg.org/gruf/go-bytes v1.0.2 // indirect
- codeberg.org/gruf/go-byteutil v1.0.0 // indirect
- codeberg.org/gruf/go-errors/v2 v2.0.1 // indirect
codeberg.org/gruf/go-fastcopy v1.1.1 // indirect
codeberg.org/gruf/go-fastpath v1.0.3 // indirect
codeberg.org/gruf/go-hashenc v1.0.2 // indirect
diff --git a/go.sum b/go.sum
index 0d7e18434..435e7a7f7 100644
--- a/go.sum
+++ b/go.sum
@@ -40,9 +40,12 @@ codeberg.org/gruf/go-bitutil v1.0.0/go.mod h1:sb8IjlDnjVTz8zPK/8lmHesKxY0Yb3iqHW
codeberg.org/gruf/go-bytes v1.0.0/go.mod h1:1v/ibfaosfXSZtRdW2rWaVrDXMc9E3bsi/M9Ekx39cg=
codeberg.org/gruf/go-bytes v1.0.2 h1:malqE42Ni+h1nnYWBUAJaDDtEzF4aeN4uPN8DfMNNvo=
codeberg.org/gruf/go-bytes v1.0.2/go.mod h1:1v/ibfaosfXSZtRdW2rWaVrDXMc9E3bsi/M9Ekx39cg=
-codeberg.org/gruf/go-byteutil v1.0.0 h1:xgKFNj/gH1r3yRo7gnyR4qrAKyeWCXs6B19ISX0DUAY=
codeberg.org/gruf/go-byteutil v1.0.0/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
+codeberg.org/gruf/go-byteutil v1.0.1 h1:cOSaqe2aytOTAC5NM62LI0w8qPfJ9n2BBddk5KyMgd0=
+codeberg.org/gruf/go-byteutil v1.0.1/go.mod h1:cWM3tgMCroSzqoBXUXMhvxTxYJp+TbCr6ioISRY5vSU=
codeberg.org/gruf/go-cache v1.1.2/go.mod h1:/Dbc+xU72Op3hMn6x2PXF3NE9uIDFeS+sXPF00hN/7o=
+codeberg.org/gruf/go-cache/v2 v2.0.1 h1:dyyfn6W6jfUlD/HWu5oz48sowSgsfKKeg2lU6T0gRww=
+codeberg.org/gruf/go-cache/v2 v2.0.1/go.mod h1:VyfrDnPVUXUKYVkXnFOHRO1EoN+8zrTC9jRU6VmL3p0=
codeberg.org/gruf/go-debug v1.1.2 h1:7Tqkktg60M/4WtXTTNUFH2T/6irBw4tI4viv7IRLZDE=
codeberg.org/gruf/go-debug v1.1.2/go.mod h1:N+vSy9uJBQgpQcJUqjctvqFz7tBHJf+S/PIjLILzpLg=
codeberg.org/gruf/go-errors/v2 v2.0.0/go.mod h1:ZRhbdhvgoUA3Yw6e56kd9Ox984RrvbEFC2pOXyHDJP4=
diff --git a/internal/api/client/account/account_test.go b/internal/api/client/account/account_test.go
index d65b49550..d6bb5a5c0 100644
--- a/internal/api/client/account/account_test.go
+++ b/internal/api/client/account/account_test.go
@@ -11,6 +11,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -20,7 +21,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -62,8 +62,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/admin/admin_test.go b/internal/api/client/admin/admin_test.go
index 578ab167c..11e2f8354 100644
--- a/internal/api/client/admin/admin_test.go
+++ b/internal/api/client/admin/admin_test.go
@@ -29,6 +29,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -38,7 +39,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -80,8 +80,8 @@ func (suite *AdminStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/fileserver/servefile_test.go b/internal/api/client/fileserver/servefile_test.go
index 49d813981..d7de2f4f9 100644
--- a/internal/api/client/fileserver/servefile_test.go
+++ b/internal/api/client/fileserver/servefile_test.go
@@ -31,6 +31,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -77,8 +77,8 @@ func (suite *ServeFileTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/followrequest/followrequest_test.go b/internal/api/client/followrequest/followrequest_test.go
index 072025931..14b5656b6 100644
--- a/internal/api/client/followrequest/followrequest_test.go
+++ b/internal/api/client/followrequest/followrequest_test.go
@@ -28,6 +28,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -37,7 +38,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -77,8 +77,8 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go
index 4d08697ef..e16b9f5eb 100644
--- a/internal/api/client/media/mediacreate_test.go
+++ b/internal/api/client/media/mediacreate_test.go
@@ -37,6 +37,7 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -47,7 +48,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -84,8 +84,8 @@ func (suite *MediaCreateTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go
index b87e6ec8d..a87718438 100644
--- a/internal/api/client/media/mediaupdate_test.go
+++ b/internal/api/client/media/mediaupdate_test.go
@@ -35,6 +35,7 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -45,7 +46,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -82,8 +82,8 @@ func (suite *MediaUpdateTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/status/status_test.go b/internal/api/client/status/status_test.go
index a4a56aa0b..e2e2819b5 100644
--- a/internal/api/client/status/status_test.go
+++ b/internal/api/client/status/status_test.go
@@ -32,6 +32,7 @@ import (
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -90,8 +90,8 @@ func (suite *StatusStandardTestSuite) SetupTest() {
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(suite.testHttpClient(), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/client/user/user_test.go b/internal/api/client/user/user_test.go
index b0fd2b2e9..6e9c46525 100644
--- a/internal/api/client/user/user_test.go
+++ b/internal/api/client/user/user_test.go
@@ -22,6 +22,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -30,7 +31,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -58,8 +58,8 @@ type UserStandardTestSuite struct {
func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
diff --git a/internal/api/s2s/user/inboxpost_test.go b/internal/api/s2s/user/inboxpost_test.go
index 6f2909430..388a9fbbb 100644
--- a/internal/api/s2s/user/inboxpost_test.go
+++ b/internal/api/s2s/user/inboxpost_test.go
@@ -33,11 +33,11 @@ import (
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -85,8 +85,8 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -188,8 +188,8 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -281,8 +281,8 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -403,8 +403,8 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/user/outboxget_test.go b/internal/api/s2s/user/outboxget_test.go
index ea9259b0f..79122731f 100644
--- a/internal/api/s2s/user/outboxget_test.go
+++ b/internal/api/s2s/user/outboxget_test.go
@@ -31,8 +31,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -46,8 +46,8 @@ func (suite *OutboxGetTestSuite) TestGetOutbox() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -104,8 +104,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -162,8 +162,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/user/repliesget_test.go b/internal/api/s2s/user/repliesget_test.go
index 4b8364318..845c07bdb 100644
--- a/internal/api/s2s/user/repliesget_test.go
+++ b/internal/api/s2s/user/repliesget_test.go
@@ -33,8 +33,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -49,8 +49,8 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -113,8 +113,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -180,8 +180,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/user/statusget_test.go b/internal/api/s2s/user/statusget_test.go
index c28e4e567..6696bd7e9 100644
--- a/internal/api/s2s/user/statusget_test.go
+++ b/internal/api/s2s/user/statusget_test.go
@@ -32,8 +32,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -48,8 +48,8 @@ func (suite *StatusGetTestSuite) TestGetStatus() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -116,8 +116,8 @@ func (suite *StatusGetTestSuite) TestGetStatusLowercase() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/user/user_test.go b/internal/api/s2s/user/user_test.go
index 1ed960544..e8d305d06 100644
--- a/internal/api/s2s/user/user_test.go
+++ b/internal/api/s2s/user/user_test.go
@@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -32,7 +33,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -78,8 +78,8 @@ func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)
diff --git a/internal/api/s2s/user/userget_test.go b/internal/api/s2s/user/userget_test.go
index 5c9e4f0d8..5ac2197ff 100644
--- a/internal/api/s2s/user/userget_test.go
+++ b/internal/api/s2s/user/userget_test.go
@@ -33,9 +33,9 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -49,8 +49,8 @@ func (suite *UserGetTestSuite) TestGetUser() {
signedRequest := derefRequests["foss_satan_dereference_zork"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -130,8 +130,8 @@ func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() {
derefRequests := testrig.NewTestDereferenceRequests(suite.testAccounts)
signedRequest := derefRequests["foss_satan_dereference_zork_public_key"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/webfinger/webfinger_test.go b/internal/api/s2s/webfinger/webfinger_test.go
index 1f597d3f9..0df50c503 100644
--- a/internal/api/s2s/webfinger/webfinger_test.go
+++ b/internal/api/s2s/webfinger/webfinger_test.go
@@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -37,7 +38,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -81,8 +81,8 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)
diff --git a/internal/api/s2s/webfinger/webfingerget_test.go b/internal/api/s2s/webfinger/webfingerget_test.go
index 55de30f34..7871b6a3f 100644
--- a/internal/api/s2s/webfinger/webfingerget_test.go
+++ b/internal/api/s2s/webfinger/webfingerget_test.go
@@ -31,10 +31,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -71,8 +71,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUser() {
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHost() {
viper.Set(config.Keys.Host, "gts.example.org")
viper.Set(config.Keys.AccountDomain, "example.org")
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
@@ -107,8 +107,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAccountDomain() {
viper.Set(config.Keys.Host, "gts.example.org")
viper.Set(config.Keys.AccountDomain, "example.org")
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
diff --git a/internal/worker/workers.go b/internal/concurrency/workers.go
similarity index 78%
rename from internal/worker/workers.go
rename to internal/concurrency/workers.go
index 6adf9ad30..2e344aece 100644
--- a/internal/worker/workers.go
+++ b/internal/concurrency/workers.go
@@ -1,4 +1,4 @@
-package worker
+package concurrency
import (
"context"
@@ -12,17 +12,17 @@ import (
"github.com/sirupsen/logrus"
)
-// Worker represents a proccessor for MsgType objects, using a worker pool to allocate resources.
-type Worker[MsgType any] struct {
+// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources.
+type WorkerPool[MsgType any] struct {
workers runners.WorkerPool
process func(context.Context, MsgType) error
prefix string // contains type prefix for logging
}
-// New returns a new Worker[MsgType] with given number of workers and queue ratio,
+// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio,
// where the queue ratio is multiplied by no. workers to get queue size. If args < 1
// then suitable defaults are determined from the runtime's GOMAXPROCS variable.
-func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
+func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] {
var zero MsgType
if workers < 1 {
@@ -38,7 +38,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
msgType := reflect.TypeOf(zero).String()
_, msgType = path.Split(msgType)
- w := &Worker[MsgType]{
+ w := &WorkerPool[MsgType]{
workers: runners.NewWorkerPool(workers, workers*queueRatio),
process: nil,
prefix: fmt.Sprintf("worker.Worker[%s]", msgType),
@@ -55,7 +55,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
}
// Start will attempt to start the underlying worker pool, or return error.
-func (w *Worker[MsgType]) Start() error {
+func (w *WorkerPool[MsgType]) Start() error {
logrus.Infof("%s starting", w.prefix)
// Check processor was set
@@ -72,7 +72,7 @@ func (w *Worker[MsgType]) Start() error {
}
// Stop will attempt to stop the underlying worker pool, or return error.
-func (w *Worker[MsgType]) Stop() error {
+func (w *WorkerPool[MsgType]) Stop() error {
logrus.Infof("%s stopping", w.prefix)
// Attempt to stop pool
@@ -84,7 +84,7 @@ func (w *Worker[MsgType]) Stop() error {
}
// SetProcessor will set the Worker's processor function, which is called for each queued message.
-func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
+func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
if w.process != nil {
logrus.Panicf("%s Worker.process is already set", w.prefix)
}
@@ -92,7 +92,7 @@ func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error)
}
// Queue will queue provided message to be processed with there's a free worker.
-func (w *Worker[MsgType]) Queue(msg MsgType) {
+func (w *WorkerPool[MsgType]) Queue(msg MsgType) {
logrus.Tracef("%s queueing message (workers=%d queue=%d): %+v",
w.prefix, w.workers.Workers(), w.workers.Queue(), msg,
)
diff --git a/internal/federation/dereferencing/dereferencer_test.go b/internal/federation/dereferencing/dereferencer_test.go
index 441019866..339490e5d 100644
--- a/internal/federation/dereferencing/dereferencer_test.go
+++ b/internal/federation/dereferencing/dereferencer_test.go
@@ -29,12 +29,12 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -150,7 +150,7 @@ func (suite *DereferencerStandardTestSuite) mockTransportController() transport.
return response, nil
}
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
mockClient := testrig.NewMockHTTPClient(do)
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
}
diff --git a/internal/federation/federatingactor_test.go b/internal/federation/federatingactor_test.go
index 4039783a4..fdf907030 100644
--- a/internal/federation/federatingactor_test.go
+++ b/internal/federation/federatingactor_test.go
@@ -28,10 +28,10 @@ import (
"time"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -57,7 +57,7 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
sentMessages := []*url.URL{}
@@ -112,7 +112,7 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
sentMessages := []*url.URL{}
diff --git a/internal/federation/federatingdb/db.go b/internal/federation/federatingdb/db.go
index 60f09b909..cbe65e922 100644
--- a/internal/federation/federatingdb/db.go
+++ b/internal/federation/federatingdb/db.go
@@ -24,10 +24,10 @@ import (
"codeberg.org/gruf/go-mutexes"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams/vocab"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// DB wraps the pub.Database interface with a couple of custom functions for GoToSocial.
@@ -44,12 +44,12 @@ type DB interface {
type federatingDB struct {
locks mutexes.MutexMap
db db.DB
- fedWorker *worker.Worker[messages.FromFederator]
+ fedWorker *concurrency.WorkerPool[messages.FromFederator]
typeConverter typeutils.TypeConverter
}
// New returns a DB interface using the given database and config
-func New(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) DB {
+func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) DB {
fdb := federatingDB{
locks: mutexes.NewMap(-1, -1), // use defaults
db: db,
diff --git a/internal/federation/federatingdb/federatingdb_test.go b/internal/federation/federatingdb/federatingdb_test.go
index d53294c1c..8e6c1802d 100644
--- a/internal/federation/federatingdb/federatingdb_test.go
+++ b/internal/federation/federatingdb/federatingdb_test.go
@@ -23,12 +23,12 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -36,7 +36,7 @@ type FederatingDBTestSuite struct {
suite.Suite
db db.DB
tc typeutils.TypeConverter
- fedWorker *worker.Worker[messages.FromFederator]
+ fedWorker *concurrency.WorkerPool[messages.FromFederator]
fromFederator chan messages.FromFederator
federatingDB federatingdb.DB
@@ -65,7 +65,7 @@ func (suite *FederatingDBTestSuite) SetupSuite() {
func (suite *FederatingDBTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.fedWorker = worker.New[messages.FromFederator](-1, -1)
+ suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.fromFederator = make(chan messages.FromFederator, 10)
suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error {
suite.fromFederator <- msg
diff --git a/internal/federation/federatingprotocol_test.go b/internal/federation/federatingprotocol_test.go
index 09817cff3..b4769a70f 100644
--- a/internal/federation/federatingprotocol_test.go
+++ b/internal/federation/federatingprotocol_test.go
@@ -28,10 +28,10 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/ap"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -44,7 +44,7 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook() {
// the activity we're gonna use
activity := suite.testActivities["dm_for_zork"]
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
@@ -78,7 +78,7 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
// now setup module being tested, with the mock transport controller
diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go
new file mode 100644
index 000000000..1a1f5e53b
--- /dev/null
+++ b/internal/httpclient/client.go
@@ -0,0 +1,199 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package httpclient
+
+import (
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "net/netip"
+ "runtime"
+ "time"
+)
+
+// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net.
+var ErrReservedAddr = errors.New("dial within blocked / reserved IP range")
+
+// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB).
+var ErrBodyTooLarge = errors.New("body size too large")
+
+// dialer is the base net.Dialer used by all package-created http.Transports.
+var dialer = &net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ Resolver: &net.Resolver{Dial: nil},
+}
+
+// Config provides configuration details for setting up a new
+// instance of httpclient.Client{}. Within are a subset of the
+// configuration values passed to initialized http.Transport{}
+// and http.Client{}, along with httpclient.Client{} specific.
+type Config struct {
+ // MaxOpenConns limits the max number of concurrent open connections.
+ MaxOpenConns int
+
+ // MaxIdleConns: see http.Transport{}.MaxIdleConns.
+ MaxIdleConns int
+
+ // ReadBufferSize: see http.Transport{}.ReadBufferSize.
+ ReadBufferSize int
+
+ // WriteBufferSize: see http.Transport{}.WriteBufferSize.
+ WriteBufferSize int
+
+ // MaxBodySize determines the maximum fetchable body size.
+ MaxBodySize int64
+
+ // Timeout: see http.Client{}.Timeout.
+ Timeout time.Duration
+
+ // DisableCompression: see http.Transport{}.DisableCompression.
+ DisableCompression bool
+
+ // AllowRanges allows outgoing communications to given IP nets.
+ AllowRanges []netip.Prefix
+
+ // BlockRanges blocks outgoing communiciations to given IP nets.
+ BlockRanges []netip.Prefix
+}
+
+// Client wraps an underlying http.Client{} to provide the following:
+// - setting a maximum received request body size, returning error on
+// large content lengths, and using a limited reader in all other
+// cases to protect against forged / unknown content-lengths
+// - protection from server side request forgery (SSRF) by only dialing
+// out to known public IP prefixes, configurable with allows/blocks
+// - limit number of concurrent requests, else blocking until a slot
+// is available (context channels still respected)
+type Client struct {
+ client http.Client
+ queue chan struct{}
+ bmax int64
+}
+
+// New returns a new instance of Client initialized using configuration.
+func New(cfg Config) *Client {
+ var c Client
+
+ // Copy global
+ d := dialer
+
+ if cfg.MaxOpenConns <= 0 {
+ // By default base this value on GOMAXPROCS.
+ maxprocs := runtime.GOMAXPROCS(0)
+ cfg.MaxOpenConns = maxprocs * 10
+ }
+
+ if cfg.MaxIdleConns <= 0 {
+ // By default base this value on MaxOpenConns
+ cfg.MaxIdleConns = cfg.MaxOpenConns * 10
+ }
+
+ if cfg.MaxBodySize <= 0 {
+ // By default set this to a reasonable 40MB
+ cfg.MaxBodySize = 40 * 1024 * 1024
+ }
+
+ // Protect dialer with IP range sanitizer
+ d.Control = (&sanitizer{
+ allow: cfg.AllowRanges,
+ block: cfg.BlockRanges,
+ }).Sanitize
+
+ // Prepare client fields
+ c.bmax = cfg.MaxBodySize
+ c.queue = make(chan struct{}, cfg.MaxOpenConns)
+ c.client.Timeout = cfg.Timeout
+
+ // Set underlying HTTP client roundtripper
+ c.client.Transport = &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ ForceAttemptHTTP2: true,
+ DialContext: d.DialContext,
+ MaxIdleConns: cfg.MaxIdleConns,
+ IdleConnTimeout: 90 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
+ ReadBufferSize: cfg.ReadBufferSize,
+ WriteBufferSize: cfg.WriteBufferSize,
+ DisableCompression: cfg.DisableCompression,
+ }
+
+ return &c
+}
+
+// Do will perform given request when an available slot in the queue is available,
+// and block until this time. For returned values, this follows the same semantics
+// as the standard http.Client{}.Do() implementation except that response body will
+// be wrapped by an io.LimitReader() to limit response body sizes.
+func (c *Client) Do(req *http.Request) (*http.Response, error) {
+ select {
+ // Request context cancelled
+ case <-req.Context().Done():
+ return nil, req.Context().Err()
+
+ // Slot in queue acquired
+ case c.queue <- struct{}{}:
+ // NOTE:
+ // Ideally here we would set the slot release to happen either
+ // on error return, or via callback from the response body closer.
+ // However when implementing this, there appear deadlocks between
+ // the channel queue here and the media manager worker pool. So
+ // currently we only place a limit on connections dialing out, but
+ // there may still be more connections open than len(c.queue) given
+ // that connections may not be closed until response body is closed.
+ // The current implementation will reduce the viability of denial of
+ // service attacks, but if there are future issues heed this advice :]
+ defer func() { <-c.queue }()
+ }
+
+ // Perform the HTTP request
+ rsp, err := c.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check response body not too large
+ if rsp.ContentLength > c.bmax {
+ return nil, ErrBodyTooLarge
+ }
+
+ // Seperate the body implementers
+ rbody := (io.Reader)(rsp.Body)
+ cbody := (io.Closer)(rsp.Body)
+
+ var limit int64
+
+ if limit = rsp.ContentLength; limit < 0 {
+ // If unknown, use max as reader limit
+ limit = c.bmax
+ }
+
+ // Don't trust them, limit body reads
+ rbody = io.LimitReader(rbody, limit)
+
+ // Wrap body with limit
+ rsp.Body = &struct {
+ io.Reader
+ io.Closer
+ }{rbody, cbody}
+
+ return rsp, nil
+}
diff --git a/internal/httpclient/client_test.go b/internal/httpclient/client_test.go
new file mode 100644
index 000000000..dc190d430
--- /dev/null
+++ b/internal/httpclient/client_test.go
@@ -0,0 +1,154 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package httpclient_test
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/netip"
+ "testing"
+
+ "github.com/superseriousbusiness/gotosocial/internal/httpclient"
+)
+
+var privateIPs = []string{
+ "http://127.0.0.1:80",
+ "http://0.0.0.0:80",
+ "http://192.168.0.1:80",
+ "http://192.168.1.0:80",
+ "http://10.0.0.0:80",
+ "http://172.16.0.0:80",
+ "http://10.255.255.255:80",
+ "http://172.31.255.255:80",
+ "http://255.255.255.255:80",
+}
+
+var bodies = []string{
+ "hello world!",
+ "{}",
+ `{"key": "value", "some": "kinda bullshit"}`,
+ "body with\r\nnewlines",
+}
+
+// Note:
+// There is no test for the .MaxOpenConns implementation
+// in the httpclient.Client{}, due to the difficult to test
+// this. The block is only held for the actual dial out to
+// the connection, so the usual test of blocking and holding
+// open this queue slot to check we can't open another isn't
+// an easy test here.
+
+func TestHTTPClientSmallBody(t *testing.T) {
+ for _, body := range bodies {
+ _TestHTTPClientWithBody(t, []byte(body), int(^uint16(0)))
+ }
+}
+
+func TestHTTPClientExactBody(t *testing.T) {
+ for _, body := range bodies {
+ _TestHTTPClientWithBody(t, []byte(body), len(body))
+ }
+}
+
+func TestHTTPClientLargeBody(t *testing.T) {
+ for _, body := range bodies {
+ _TestHTTPClientWithBody(t, []byte(body), len(body)-1)
+ }
+}
+
+func _TestHTTPClientWithBody(t *testing.T, body []byte, max int) {
+ var (
+ handler http.HandlerFunc
+
+ expect []byte
+
+ expectErr error
+ )
+
+ // If this is a larger body, reslice and
+ // set error so we know what to expect
+ expect = body
+ if max < len(body) {
+ expect = expect[:max]
+ expectErr = httpclient.ErrBodyTooLarge
+ }
+
+ // Create new HTTP client with maximum body size
+ client := httpclient.New(httpclient.Config{
+ MaxBodySize: int64(max),
+ DisableCompression: true,
+ AllowRanges: []netip.Prefix{
+ // Loopback (used by server)
+ netip.MustParsePrefix("127.0.0.1/8"),
+ },
+ })
+
+ // Set simple body-writing test handler
+ handler = func(rw http.ResponseWriter, r *http.Request) {
+ _, _ = rw.Write(body)
+ }
+
+ // Start the test server
+ srv := httptest.NewServer(handler)
+ defer srv.Close()
+
+ // Wrap body to provide reader iface
+ rbody := bytes.NewReader(body)
+
+ // Create the test HTTP request
+ req, _ := http.NewRequest("POST", srv.URL, rbody)
+
+ // Perform the test request
+ rsp, err := client.Do(req)
+ if !errors.Is(err, expectErr) {
+ t.Fatalf("error performing client request: %v", err)
+ } else if err != nil {
+ return // expected error
+ }
+ defer rsp.Body.Close()
+
+ // Read response body into memory
+ check, err := io.ReadAll(rsp.Body)
+ if err != nil {
+ t.Fatalf("error reading response body: %v", err)
+ }
+
+ // Check actual response body matches expected
+ if !bytes.Equal(expect, check) {
+ t.Errorf("response body did not match expected: expect=%q actual=%q", string(expect), string(check))
+ }
+}
+
+func TestHTTPClientPrivateIP(t *testing.T) {
+ client := httpclient.New(httpclient.Config{})
+
+ for _, addr := range privateIPs {
+ // Prepare request to private IP
+ req, _ := http.NewRequest("GET", addr, nil)
+
+ // Perform the HTTP request
+ _, err := client.Do(req)
+ if !errors.Is(err, httpclient.ErrReservedAddr) {
+ t.Errorf("dialing private address did not return expected error: %v", err)
+ }
+ }
+}
diff --git a/internal/httpclient/sanitizer.go b/internal/httpclient/sanitizer.go
new file mode 100644
index 000000000..6eef6898a
--- /dev/null
+++ b/internal/httpclient/sanitizer.go
@@ -0,0 +1,64 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package httpclient
+
+import (
+ "net/netip"
+ "syscall"
+
+ "github.com/superseriousbusiness/gotosocial/internal/netutil"
+)
+
+type sanitizer struct {
+ allow []netip.Prefix
+ block []netip.Prefix
+}
+
+// Sanitize implements the required net.Dialer.Control function signature.
+func (s *sanitizer) Sanitize(ntwrk, addr string, _ syscall.RawConn) error {
+ // Parse IP+port from addr
+ ipport, err := netip.ParseAddrPort(addr)
+ if err != nil {
+ return err
+ }
+
+ // Seperate the IP
+ ip := ipport.Addr()
+
+ // Check if this is explicitly allowed
+ for i := 0; i < len(s.allow); i++ {
+ if s.allow[i].Contains(ip) {
+ return nil
+ }
+ }
+
+ // Now check if explicity blocked
+ for i := 0; i < len(s.block); i++ {
+ if s.block[i].Contains(ip) {
+ return ErrReservedAddr
+ }
+ }
+
+ // Validate this is a safe IP
+ if !netutil.ValidateIP(ip) {
+ return ErrReservedAddr
+ }
+
+ return nil
+}
diff --git a/internal/media/manager.go b/internal/media/manager.go
index 174fca8e2..5b4a01021 100644
--- a/internal/media/manager.go
+++ b/internal/media/manager.go
@@ -27,9 +27,9 @@ import (
"github.com/robfig/cron/v3"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Manager provides an interface for managing media: parsing, storing, and retrieving media objects like photos, videos, and gifs.
@@ -79,8 +79,8 @@ type Manager interface {
type manager struct {
db db.DB
storage *kv.KVStore
- emojiWorker *worker.Worker[*ProcessingEmoji]
- mediaWorker *worker.Worker[*ProcessingMedia]
+ emojiWorker *concurrency.WorkerPool[*ProcessingEmoji]
+ mediaWorker *concurrency.WorkerPool[*ProcessingMedia]
stopCronJobs func() error
}
@@ -89,7 +89,7 @@ type manager struct {
// A worker pool will also be initialized for the manager, to ensure that only
// a limited number of media will be processed in parallel. The numbers of workers
// is determined from the $GOMAXPROCS environment variable (usually no. CPU cores).
-// See internal/worker.New() documentation for further information.
+// See internal/concurrency.NewWorkerPool() documentation for further information.
func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
m := &manager{
db: database,
@@ -97,7 +97,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
}
// Prepare the media worker pool
- m.mediaWorker = worker.New[*ProcessingMedia](-1, 10)
+ m.mediaWorker = concurrency.NewWorkerPool[*ProcessingMedia](-1, 10)
m.mediaWorker.SetProcessor(func(ctx context.Context, media *ProcessingMedia) error {
if err := ctx.Err(); err != nil {
return err
@@ -109,7 +109,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
})
// Prepare the emoji worker pool
- m.emojiWorker = worker.New[*ProcessingEmoji](-1, 10)
+ m.emojiWorker = concurrency.NewWorkerPool[*ProcessingEmoji](-1, 10)
m.emojiWorker.SetProcessor(func(ctx context.Context, emoji *ProcessingEmoji) error {
if err := ctx.Err(); err != nil {
return err
diff --git a/internal/netutil/validate.go b/internal/netutil/validate.go
new file mode 100644
index 000000000..27cc9ba4a
--- /dev/null
+++ b/internal/netutil/validate.go
@@ -0,0 +1,78 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package netutil
+
+import (
+ "net/netip"
+)
+
+var (
+ // IPv6GlobalUnicast is the global IPv6 unicast IP prefix.
+ IPv6GlobalUnicast = netip.MustParsePrefix("ff00::/8")
+
+ // IPvReserved contains IPv4 reserved IP prefixes.
+ IPv4Reserved = [...]netip.Prefix{
+ netip.MustParsePrefix("0.0.0.0/8"), // Current network
+ netip.MustParsePrefix("10.0.0.0/8"), // Private
+ netip.MustParsePrefix("100.64.0.0/10"), // RFC6598
+ netip.MustParsePrefix("127.0.0.0/8"), // Loopback
+ netip.MustParsePrefix("169.254.0.0/16"), // Link-local
+ netip.MustParsePrefix("172.16.0.0/12"), // Private
+ netip.MustParsePrefix("192.0.0.0/24"), // RFC6890
+ netip.MustParsePrefix("192.0.2.0/24"), // Test, doc, examples
+ netip.MustParsePrefix("192.88.99.0/24"), // IPv6 to IPv4 relay
+ netip.MustParsePrefix("192.168.0.0/16"), // Private
+ netip.MustParsePrefix("198.18.0.0/15"), // Benchmarking tests
+ netip.MustParsePrefix("198.51.100.0/24"), // Test, doc, examples
+ netip.MustParsePrefix("203.0.113.0/24"), // Test, doc, examples
+ netip.MustParsePrefix("224.0.0.0/4"), // Multicast
+ netip.MustParsePrefix("240.0.0.0/4"), // Reserved (includes broadcast / 255.255.255.255)
+ }
+)
+
+// ValidateAddr will parse a netip.AddrPort from string, and return the result of ValidateIP() on addr.
+func ValidateAddr(s string) bool {
+ ipport, err := netip.ParseAddrPort(s)
+ if err != nil {
+ return false
+ }
+ return ValidateIP(ipport.Addr())
+}
+
+// ValidateIP returns whether IP is an IPv4/6 address in non-reserved, public ranges.
+func ValidateIP(ip netip.Addr) bool {
+ switch {
+ // IPv4: check if IPv4 in reserved nets
+ case ip.Is4():
+ for _, reserved := range IPv4Reserved {
+ if reserved.Contains(ip) {
+ return false
+ }
+ }
+ return true
+
+ // IPv6: check if in global unicast (public internet)
+ case ip.Is6():
+ return IPv6GlobalUnicast.Contains(ip)
+
+ // Assume malicious by default
+ default:
+ return false
+ }
+}
diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go
index c49df1a1a..7668da02c 100644
--- a/internal/processing/account/account.go
+++ b/internal/processing/account/account.go
@@ -23,6 +23,7 @@ import (
"mime/multipart"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
@@ -33,7 +34,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/oauth2/v4"
)
@@ -84,7 +84,7 @@ type Processor interface {
type processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
- clientWorker *worker.Worker[messages.FromClientAPI]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
oauthServer oauth.Server
filter visibility.Filter
formatter text.Formatter
@@ -94,7 +94,7 @@ type processor struct {
}
// New returns a new account processor.
-func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *worker.Worker[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
+func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
return &processor{
tc: tc,
mediaManager: mediaManager,
diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go
index 33b744250..d9ce68cc0 100644
--- a/internal/processing/account/account_test.go
+++ b/internal/processing/account/account_test.go
@@ -24,6 +24,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -35,7 +36,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -81,8 +81,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
suite.fromClientAPIChan <- msg
return nil
diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go
index 4b466a2d7..6779f59b7 100644
--- a/internal/processing/admin/admin.go
+++ b/internal/processing/admin/admin.go
@@ -23,13 +23,13 @@ import (
"mime/multipart"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor wraps a bunch of functions for processing admin actions.
@@ -47,12 +47,12 @@ type Processor interface {
type processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
- clientWorker *worker.Worker[messages.FromClientAPI]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
db db.DB
}
// New returns a new admin processor.
-func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI]) Processor {
+func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
return &processor{
tc: tc,
mediaManager: mediaManager,
diff --git a/internal/processing/media/media_test.go b/internal/processing/media/media_test.go
index af67b36b1..1149f2646 100644
--- a/internal/processing/media/media_test.go
+++ b/internal/processing/media/media_test.go
@@ -26,6 +26,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
@@ -33,7 +34,6 @@ import (
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -122,7 +122,7 @@ func (suite *MediaStandardTestSuite) mockTransportController() transport.Control
return response, nil
}
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
mockClient := testrig.NewMockHTTPClient(do)
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
}
diff --git a/internal/processing/processor.go b/internal/processing/processor.go
index 69f3100f9..d30f2f37e 100644
--- a/internal/processing/processor.go
+++ b/internal/processing/processor.go
@@ -25,6 +25,7 @@ import (
"codeberg.org/gruf/go-store/kv"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -44,7 +45,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor should be passed to api modules (see internal/apimodule/...). It is used for
@@ -237,8 +237,8 @@ type Processor interface {
// processor just implements the Processor interface
type processor struct {
- clientWorker *worker.Worker[messages.FromClientAPI]
- fedWorker *worker.Worker[messages.FromFederator]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
+ fedWorker *concurrency.WorkerPool[messages.FromFederator]
federator federation.Federator
tc typeutils.TypeConverter
@@ -271,8 +271,8 @@ func NewProcessor(
storage *kv.KVStore,
db db.DB,
emailSender email.Sender,
- clientWorker *worker.Worker[messages.FromClientAPI],
- fedWorker *worker.Worker[messages.FromFederator],
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
+ fedWorker *concurrency.WorkerPool[messages.FromFederator],
) Processor {
parseMentionFunc := GetParseMentionFunc(db, federator)
diff --git a/internal/processing/processor_test.go b/internal/processing/processor_test.go
index 7e1972366..5946e6718 100644
--- a/internal/processing/processor_test.go
+++ b/internal/processing/processor_test.go
@@ -29,6 +29,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -217,8 +217,8 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
}, nil
})
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.transportController = testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go
index 207bffb30..e8b4a8268 100644
--- a/internal/processing/status/status.go
+++ b/internal/processing/status/status.go
@@ -22,6 +22,7 @@ import (
"context"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -29,7 +30,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor wraps a bunch of functions for processing statuses.
@@ -74,12 +74,12 @@ type processor struct {
db db.DB
filter visibility.Filter
formatter text.Formatter
- clientWorker *worker.Worker[messages.FromClientAPI]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
parseMention gtsmodel.ParseMentionFunc
}
// New returns a new status processor.
-func New(db db.DB, tc typeutils.TypeConverter, clientWorker *worker.Worker[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
+func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
return &processor{
tc: tc,
db: db,
diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go
index d2126f03d..17c68c0b6 100644
--- a/internal/processing/status/status_test.go
+++ b/internal/processing/status/status_test.go
@@ -21,6 +21,7 @@ package status_test
import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -30,7 +31,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -42,7 +42,7 @@ type StatusStandardTestSuite struct {
storage *kv.KVStore
mediaManager media.Manager
federator federation.Federator
- clientWorker *worker.Worker[messages.FromClientAPI]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -75,11 +75,11 @@ func (suite *StatusStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.typeConverter = testrig.NewTestTypeConverter(suite.db)
- suite.clientWorker = worker.New[messages.FromClientAPI](-1, -1)
+ suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
suite.storage = testrig.NewTestStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
diff --git a/internal/transport/controller.go b/internal/transport/controller.go
index 56a922a8b..280d4bc0b 100644
--- a/internal/transport/controller.go
+++ b/internal/transport/controller.go
@@ -20,13 +20,17 @@ package transport
import (
"context"
- "crypto"
+ "crypto/rsa"
+ "crypto/x509"
"encoding/json"
"fmt"
"net/url"
- "sync"
+ "runtime/debug"
+ "time"
- "github.com/go-fed/httpsig"
+ "codeberg.org/gruf/go-byteutil"
+ "codeberg.org/gruf/go-cache/v2"
+ "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
@@ -37,109 +41,85 @@ import (
// Controller generates transports for use in making federation requests to other servers.
type Controller interface {
- NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
+ // NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key.
+ NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error)
+
+ // NewTransportForUsername searches for account with username, and returns result of .NewTransport().
NewTransportForUsername(ctx context.Context, username string) (Transport, error)
}
type controller struct {
- db db.DB
- clock pub.Clock
- client pub.HttpClient
- appAgent string
-
- // dereferenceFollowersShortcut is a shortcut to dereference followers of an
- // account on this instance, without making any external api/http calls.
- //
- // It is passed to new transports, and should only be invoked when the iri.Host == this host.
- dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
-
- // dereferenceUserShortcut is a shortcut to dereference followers an account on
- // this instance, without making any external api/http calls.
- //
- // It is passed to new transports, and should only be invoked when the iri.Host == this host.
- dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
-}
-
-func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
- return func(ctx context.Context, iri *url.URL) ([]byte, error) {
- followers, err := federatingDB.Followers(ctx, iri)
- if err != nil {
- return nil, err
- }
-
- i, err := streams.Serialize(followers)
- if err != nil {
- return nil, err
- }
-
- return json.Marshal(i)
- }
-}
-
-func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
- return func(ctx context.Context, iri *url.URL) ([]byte, error) {
- user, err := federatingDB.Get(ctx, iri)
- if err != nil {
- return nil, err
- }
-
- i, err := streams.Serialize(user)
- if err != nil {
- return nil, err
- }
-
- return json.Marshal(i)
- }
+ db db.DB
+ fedDB federatingdb.DB
+ clock pub.Clock
+ client pub.HttpClient
+ cache cache.Cache[string, *transport]
+ userAgent string
}
// NewController returns an implementation of the Controller interface for creating new transports
func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
applicationName := viper.GetString(config.Keys.ApplicationName)
host := viper.GetString(config.Keys.Host)
- appAgent := fmt.Sprintf("%s %s", applicationName, host)
- return &controller{
- db: db,
- clock: clock,
- client: client,
- appAgent: appAgent,
- dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB),
- dereferenceUserShortcut: dereferenceUserShortcut(federatingDB),
+ // Determine build information
+ build, _ := debug.ReadBuildInfo()
+
+ c := &controller{
+ db: db,
+ fedDB: federatingDB,
+ clock: clock,
+ client: client,
+ cache: cache.New[string, *transport](),
+ userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version),
}
+
+ // Transport cache has TTL=1hr freq=1m
+ c.cache.SetTTL(time.Hour, false)
+ if !c.cache.Start(time.Minute) {
+ logrus.Panic("failed to start transport controller cache")
+ }
+
+ return c
}
-// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key.
-func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) {
- prefs := []httpsig.Algorithm{httpsig.RSA_SHA256}
- digestAlgo := httpsig.DigestSha256
- getHeaders := []string{httpsig.RequestTarget, "host", "date"}
- postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"}
+func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) {
+ // Generate public key string for cache key
+ //
+ // NOTE: it is safe to use the public key as the cache
+ // key here as we are generating it ourselves from the
+ // private key. If we were simply using a public key
+ // provided as argument that would absolutely NOT be safe.
+ pubStr := privkeyToPublicStr(privkey)
- getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120)
- if err != nil {
- return nil, fmt.Errorf("error creating get signer: %s", err)
+ // First check for cached transport
+ transp, ok := c.cache.Get(pubStr)
+ if ok {
+ return transp, nil
}
- postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120)
- if err != nil {
- return nil, fmt.Errorf("error creating post signer: %s", err)
+ // Create the transport
+ transp = &transport{
+ controller: c,
+ pubKeyID: pubKeyID,
+ privkey: privkey,
}
- sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey)
+ // Cache this transport under pubkey
+ if !c.cache.Put(pubStr, transp) {
+ var cached *transport
- return &transport{
- client: c.client,
- appAgent: c.appAgent,
- gofedAgent: "(go-fed/activity v1.0.0)",
- clock: c.clock,
- pubKeyID: pubKeyID,
- privkey: privkey,
- sigTransport: sigTransport,
- getSigner: getSigner,
- getSignerMu: &sync.Mutex{},
- dereferenceFollowersShortcut: c.dereferenceFollowersShortcut,
- dereferenceUserShortcut: c.dereferenceUserShortcut,
- }, nil
+ cached, ok = c.cache.Get(pubStr)
+ if !ok {
+ // Some ridiculous race cond.
+ c.cache.Set(pubStr, transp)
+ } else {
+ // Use already cached
+ transp = cached
+ }
+ }
+
+ return transp, nil
}
func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
@@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin
}
return transport, nil
}
+
+// dereferenceLocalFollowers is a shortcut to dereference followers of an
+// account on this instance, without making any external api/http calls.
+//
+// It is passed to new transports, and should only be invoked when the iri.Host == this host.
+func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) {
+ followers, err := c.fedDB.Followers(ctx, iri)
+ if err != nil {
+ return nil, err
+ }
+
+ i, err := streams.Serialize(followers)
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(i)
+}
+
+// dereferenceLocalUser is a shortcut to dereference followers an account on
+// this instance, without making any external api/http calls.
+//
+// It is passed to new transports, and should only be invoked when the iri.Host == this host.
+func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) {
+ user, err := c.fedDB.Get(ctx, iri)
+ if err != nil {
+ return nil, err
+ }
+
+ i, err := streams.Serialize(user)
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(i)
+}
+
+// privkeyToPublicStr will create a string representation of RSA public key from private.
+func privkeyToPublicStr(privkey *rsa.PrivateKey) string {
+ b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey)
+ return byteutil.B2S(b)
+}
diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go
index fe17f7761..bacaa9b3a 100644
--- a/internal/transport/deliver.go
+++ b/internal/transport/deliver.go
@@ -19,13 +19,14 @@
package transport
import (
+ "bytes"
"context"
"fmt"
+ "net/http"
"net/url"
"strings"
"sync"
- "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
@@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
return nil
}
- logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String())
- return t.sigTransport.Deliver(ctx, b, to)
+ urlStr := to.String()
+
+ req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b))
+ if err != nil {
+ return err
+ }
+
+ req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
+ req.Header.Add("Accept-Charset", "utf-8")
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", to.Host)
+
+ resp, err := t.POST(req, b)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if code := resp.StatusCode; code != http.StatusOK &&
+ code != http.StatusCreated && code != http.StatusAccepted {
+ return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status)
+ }
+
+ return nil
}
diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go
index 61d99c5c5..36157b673 100644
--- a/internal/transport/dereference.go
+++ b/internal/transport/dereference.go
@@ -20,32 +20,55 @@ package transport
import (
"context"
+ "fmt"
+ "io/ioutil"
+ "net/http"
"net/url"
- "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/uris"
)
func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
- l := logrus.WithField("func", "Dereference")
-
// if the request is to us, we can shortcut for certain URIs rather than going through
// the normal request flow, thereby saving time and energy
if iri.Host == viper.GetString(config.Keys.Host) {
if uris.IsFollowersPath(iri) {
// the request is for followers of one of our accounts, which we can shortcut
- return t.dereferenceFollowersShortcut(ctx, iri)
+ return t.controller.dereferenceLocalFollowers(ctx, iri)
}
if uris.IsUserPath(iri) {
// the request is for one of our accounts, which we can shortcut
- return t.dereferenceUserShortcut(ctx, iri)
+ return t.controller.dereferenceLocalUser(ctx, iri)
}
}
- // the request is either for a remote host or for us but we don't have a shortcut, so continue as normal
- l.Debugf("performing GET to %s", iri.String())
- return t.sigTransport.Dereference(ctx, iri)
+ // Build IRI just once
+ iriStr := iri.String()
+
+ // Prepare new HTTP request to endpoint
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
+ req.Header.Add("Accept-Charset", "utf-8")
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", iri.Host)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
+ if err != nil {
+ return nil, err
+ }
+ defer rsp.Body.Close()
+
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
+ }
+
+ return ioutil.ReadAll(rsp.Body)
}
diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go
index c64dced0f..1acbcc364 100644
--- a/internal/transport/derefinstance.go
+++ b/internal/transport/derefinstance.go
@@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts
}
func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
- l := logrus.WithField("func", "dereferenceByAPIV1Instance")
-
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: "api/v1/instance",
}
- l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
- if err != nil {
- return nil, err
- }
- req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
- req.Header.Set("Host", cleanIRI.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
- }
- b, err := ioutil.ReadAll(resp.Body)
+ // Build IRI just once
+ iriStr := cleanIRI.String()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
- if len(b) == 0 {
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", cleanIRI.Host)
+
+ resp, err := t.GET(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
+ }
+
+ b, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ } else if len(b) == 0 {
return nil, errors.New("response bytes was len 0")
}
@@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
}
func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
- l := logrus.WithField("func", "callNodeInfoWellKnown")
-
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: ".well-known/nodeinfo",
}
- l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
- if err != nil {
- return nil, err
- }
+ // Build IRI just once
+ iriStr := cleanIRI.String()
- req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
- req.Header.Set("Host", cleanIRI.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
- resp, err := t.client.Do(req)
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", cleanIRI.Host)
+
+ resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
+
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
+ return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
+
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
- }
-
- if len(b) == 0 {
+ } else if len(b) == 0 {
return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0")
}
@@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur
}
func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
- l := logrus.WithField("func", "callNodeInfo")
+ // Build IRI just once
+ iriStr := iri.String()
- l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
-
req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+
+ resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
+
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
+ return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
+
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
- }
-
- if len(b) == 0 {
+ } else if len(b) == 0 {
return nil, errors.New("callNodeInfo: response bytes was len 0")
}
diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go
index e3c86ce1e..8feb7ed20 100644
--- a/internal/transport/derefmedia.go
+++ b/internal/transport/derefmedia.go
@@ -24,34 +24,31 @@ import (
"io"
"net/http"
"net/url"
-
- "github.com/sirupsen/logrus"
)
func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) {
- l := logrus.WithField("func", "DereferenceMedia")
- l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ // Build IRI just once
+ iriStr := iri.String()
+
+ // Prepare HTTP request to this media's IRI
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
+ if err != nil {
+ return nil, 0, err
+ }
+ req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", iri.Host)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
if err != nil {
return nil, 0, err
}
- req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
- req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, 0, err
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
}
- resp, err := t.client.Do(req)
- if err != nil {
- return nil, 0, err
- }
- if resp.StatusCode != http.StatusOK {
- return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
- }
- return resp.Body, int(resp.ContentLength), nil
+
+ return rsp.Body, int(rsp.ContentLength), nil
}
diff --git a/internal/transport/finger.go b/internal/transport/finger.go
index a71bbb51e..7554a242f 100644
--- a/internal/transport/finger.go
+++ b/internal/transport/finger.go
@@ -23,46 +23,36 @@ import (
"fmt"
"io/ioutil"
"net/http"
- "net/url"
-
- "github.com/sirupsen/logrus"
)
func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
- l := logrus.WithField("func", "Finger")
- urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
- l.Debugf("performing GET to %s", urlString)
+ // Prepare URL string
+ urlStr := "https://" +
+ targetDomain +
+ "/.well-known/webfinger?resource=acct:" +
+ targetUsername + "@" + targetDomain
- iri, err := url.Parse(urlString)
- if err != nil {
- return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err)
- }
-
- l.Debugf("performing GET to %s", iri.String())
-
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ // Generate new GET request from URL string
+ req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return nil, err
}
-
req.Header.Add("Accept", "application/json")
req.Header.Add("Accept", "application/jrd+json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
- req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", req.URL.Host)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
if err != nil {
return nil, err
}
- resp, err := t.client.Do(req)
- if err != nil {
- return nil, err
+ defer rsp.Body.Close()
+
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status)
}
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
- }
- return ioutil.ReadAll(resp.Body)
+
+ return ioutil.ReadAll(rsp.Body)
}
diff --git a/internal/transport/signing.go b/internal/transport/signing.go
new file mode 100644
index 000000000..39896a2a8
--- /dev/null
+++ b/internal/transport/signing.go
@@ -0,0 +1,43 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package transport
+
+import (
+ "github.com/go-fed/httpsig"
+)
+
+var (
+ // http signer preferences
+ prefs = []httpsig.Algorithm{httpsig.RSA_SHA256}
+ digestAlgo = httpsig.DigestSha256
+ getHeaders = []string{httpsig.RequestTarget, "host", "date"}
+ postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"}
+)
+
+// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences.
+func NewGETSigner(expiresIn int64) (httpsig.Signer, error) {
+ sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn)
+ return sig, err
+}
+
+// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences.
+func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) {
+ sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn)
+ return sig, err
+}
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index 40c11ca17..c52686c43 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -21,11 +21,18 @@ package transport
import (
"context"
"crypto"
+ "crypto/x509"
+ "errors"
"io"
+ "net/http"
"net/url"
+ "strings"
"sync"
+ "time"
+ errorsv2 "codeberg.org/gruf/go-errors/v2"
"github.com/go-fed/httpsig"
+ "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -43,28 +50,148 @@ type Transport interface {
DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
- // SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport.
- SigTransport() pub.Transport
}
// transport implements the Transport interface
type transport struct {
- client pub.HttpClient
- appAgent string
- gofedAgent string
- clock pub.Clock
- pubKeyID string
- privkey crypto.PrivateKey
- sigTransport *pub.HttpSigTransport
- getSigner httpsig.Signer
- getSignerMu *sync.Mutex
+ controller *controller
+ pubKeyID string
+ privkey crypto.PrivateKey
- // shortcuts for dereferencing things that exist on our instance without making an http call to ourself
-
- dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
- dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
+ signerExp time.Time
+ getSigner httpsig.Signer
+ postSigner httpsig.Signer
+ signerMu sync.Mutex
}
-func (t *transport) SigTransport() pub.Transport {
- return t.sigTransport
+// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
+func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) {
+ if r.Method != http.MethodGet {
+ return nil, errors.New("must be GET request")
+ }
+ return t.do(r, func(r *http.Request) error {
+ return t.signGET(r)
+ }, retryOn...)
+}
+
+// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
+func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) {
+ if r.Method != http.MethodPost {
+ return nil, errors.New("must be POST request")
+ }
+ return t.do(r, func(r *http.Request) error {
+ return t.signPOST(r, body)
+ }, retryOn...)
+}
+
+func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) {
+ const maxRetries = 5
+ backoff := time.Second * 2
+
+ // Start a log entry for this request
+ l := logrus.WithFields(logrus.Fields{
+ "pubKeyID": t.pubKeyID,
+ "method": r.Method,
+ "url": r.URL.String(),
+ })
+
+ for i := 0; i < maxRetries; i++ {
+ // Reset signing header fields
+ now := t.controller.clock.Now().UTC()
+ r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
+ r.Header.Del("Signature")
+ r.Header.Del("Digest")
+
+ // Perform request signing
+ if err := signer(r); err != nil {
+ return nil, err
+ }
+
+ l.Infof("performing request")
+
+ // Attempt to perform request
+ rsp, err := t.controller.client.Do(r)
+ if err == nil { //nolint shutup linter
+ // TooManyRequest means we need to slow
+ // down and retry our request. Codes over
+ // 500 generally indicate temp. outages.
+ if code := rsp.StatusCode; code < 500 &&
+ code != http.StatusTooManyRequests &&
+ !containsInt(retryOn, rsp.StatusCode) {
+ return rsp, nil
+ }
+
+ // Generate error from status code for logging
+ err = errors.New(`http response "` + rsp.Status + `"`)
+ } else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) {
+ // Return early if context has cancelled
+ return nil, err
+ } else if strings.Contains(err.Error(), "stopped after 10 redirects") {
+ // Don't bother if net/http returned after too many redirects
+ return nil, err
+ } else if errors.As(err, &x509.UnknownAuthorityError{}) {
+ // Unknown authority errors we do NOT recover from
+ return nil, err
+ }
+
+ l.Errorf("backing off for %s after http request error: %v", backoff.String(), err)
+
+ select {
+ // Request ctx cancelled
+ case <-r.Context().Done():
+ return nil, r.Context().Err()
+
+ // Backoff for some time
+ case <-time.After(backoff):
+ backoff *= 2
+ }
+ }
+
+ return nil, errors.New("transport reached max retries")
+}
+
+// signGET will safely sign an HTTP GET request.
+func (t *transport) signGET(r *http.Request) (err error) {
+ t.safesign(func() {
+ err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil)
+ })
+ return
+}
+
+// signPOST will safely sign an HTTP POST request for given body.
+func (t *transport) signPOST(r *http.Request, body []byte) (err error) {
+ t.safesign(func() {
+ err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body)
+ })
+ return
+}
+
+// safesign will perform sign function within mutex protection,
+// and ensured that httpsig.Signers are up-to-date.
+func (t *transport) safesign(sign func()) {
+ // Perform within mu safety
+ t.signerMu.Lock()
+ defer t.signerMu.Unlock()
+
+ if now := time.Now(); now.After(t.signerExp) {
+ const expiry = 120
+
+ // Signers have expired and require renewal
+ t.getSigner, _ = NewGETSigner(expiry)
+ t.postSigner, _ = NewPOSTSigner(expiry)
+ t.signerExp = now.Add(time.Second * expiry)
+ }
+
+ // Perform signing
+ sign()
+}
+
+// containsInt checks if slice contains check.
+func containsInt(slice []int, check int) bool {
+ for _, i := range slice {
+ if i == check {
+ return true
+ }
+ }
+ return false
}
diff --git a/testrig/federatingdb.go b/testrig/federatingdb.go
index 3f71274ca..468cfbfd1 100644
--- a/testrig/federatingdb.go
+++ b/testrig/federatingdb.go
@@ -1,13 +1,13 @@
package testrig
import (
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// NewTestFederatingDB returns a federating DB with the underlying db
-func NewTestFederatingDB(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) federatingdb.DB {
+func NewTestFederatingDB(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federatingdb.DB {
return federatingdb.New(db, fedWorker)
}
diff --git a/testrig/federator.go b/testrig/federator.go
index 475ed3346..0546325ab 100644
--- a/testrig/federator.go
+++ b/testrig/federator.go
@@ -20,15 +20,15 @@ package testrig
import (
"codeberg.org/gruf/go-store/kv"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// NewTestFederator returns a federator with the given database and (mock!!) transport controller.
-func NewTestFederator(db db.DB, tc transport.Controller, storage *kv.KVStore, mediaManager media.Manager, fedWorker *worker.Worker[messages.FromFederator]) federation.Federator {
+func NewTestFederator(db db.DB, tc transport.Controller, storage *kv.KVStore, mediaManager media.Manager, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federation.Federator {
return federation.NewFederator(db, NewTestFederatingDB(db, fedWorker), tc, NewTestTypeConverter(db), mediaManager)
}
diff --git a/testrig/processor.go b/testrig/processor.go
index c0fbd8a74..15f9040f7 100644
--- a/testrig/processor.go
+++ b/testrig/processor.go
@@ -20,16 +20,16 @@ package testrig
import (
"codeberg.org/gruf/go-store/kv"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// NewTestProcessor returns a Processor suitable for testing purposes
-func NewTestProcessor(db db.DB, storage *kv.KVStore, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI], fedWorker *worker.Worker[messages.FromFederator]) processing.Processor {
+func NewTestProcessor(db db.DB, storage *kv.KVStore, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], fedWorker *concurrency.WorkerPool[messages.FromFederator]) processing.Processor {
return processing.NewProcessor(NewTestTypeConverter(db), federator, NewTestOauthServer(db), mediaManager, storage, db, emailSender, clientWorker, fedWorker)
}
diff --git a/testrig/testmodels.go b/testrig/testmodels.go
index 8894e562d..a74357feb 100644
--- a/testrig/testmodels.go
+++ b/testrig/testmodels.go
@@ -20,8 +20,6 @@ package testrig
import (
"bytes"
- "context"
- "crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
@@ -29,7 +27,6 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
- "io/ioutil"
"net"
"net/http"
"net/url"
@@ -42,8 +39,7 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
+ "github.com/superseriousbusiness/gotosocial/internal/transport"
)
// NewTestTokens returns a map of tokens keyed according to which account the token belongs to.
@@ -1855,86 +1851,71 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin
}
}
-// GetSignatureForActivity does some sneaky sneaky work with a mock http client and a test transport controller, in order to derive
-// the HTTP Signature for the given activity, public key ID, private key, and destination.
-func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey crypto.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
- // create a client that basically just pulls the signature out of the request and sets it
- client := &mockHTTPClient{
- do: func(req *http.Request) (*http.Response, error) {
- signatureHeader = req.Header.Get("Signature")
- digestHeader = req.Header.Get("Digest")
- dateHeader = req.Header.Get("Date")
- r := ioutil.NopCloser(bytes.NewReader([]byte{})) // we only need this so the 'close' func doesn't nil out
- return &http.Response{
- StatusCode: 200,
- Body: r,
- }, nil
- },
- }
-
- // Create temporary federator worker for transport controller
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- _ = fedWorker.Start()
- defer func() { _ = fedWorker.Stop() }()
-
- // use the client to create a new transport
- c := NewTestTransportController(client, NewTestDB(), fedWorker)
- tp, err := c.NewTransport(pubKeyID, privkey)
- if err != nil {
- panic(err)
- }
-
+// GetSignatureForActivity prepares a mock HTTP request as if it were going to deliver activity to destination signed for privkey and pubKeyID, signs the request and returns the header values.
+func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
// convert the activity into json bytes
m, err := activity.Serialize()
if err != nil {
panic(err)
}
- bytes, err := json.Marshal(m)
+ b, err := json.Marshal(m)
if err != nil {
panic(err)
}
- // trigger the delivery function for the underlying signature transport, which will trigger the 'do' function of the recorder above
- if err := tp.SigTransport().Deliver(context.Background(), bytes, destination); err != nil {
+ // Prepare HTTP request signer
+ sig, err := transport.NewPOSTSigner(120)
+ if err != nil {
panic(err)
}
+ // Prepare a mock request ready for signing
+ r, err := http.NewRequest("POST", destination.String(), bytes.NewReader(b))
+ if err != nil {
+ panic(err)
+ }
+ r.Header.Set("Host", destination.Host)
+ r.Header.Set("Date", time.Now().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
+
+ // Sign this new HTTP request
+ if err := sig.SignRequest(privkey, pubKeyID, r, b); err != nil {
+ panic(err)
+ }
+
+ // Load signed data from request
+ signatureHeader = r.Header.Get("Signature")
+ digestHeader = r.Header.Get("Digest")
+ dateHeader = r.Header.Get("Date")
+
// headers should now be populated
return
}
-// GetSignatureForDereference does some sneaky sneaky work with a mock http client and a test transport controller, in order to derive
-// the HTTP Signature for the given derefence GET request using public key ID, private key, and destination.
-func GetSignatureForDereference(pubKeyID string, privkey crypto.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
- // create a client that basically just pulls the signature out of the request and sets it
- client := &mockHTTPClient{
- do: func(req *http.Request) (*http.Response, error) {
- signatureHeader = req.Header.Get("Signature")
- dateHeader = req.Header.Get("Date")
- r := ioutil.NopCloser(bytes.NewReader([]byte{})) // we only need this so the 'close' func doesn't nil out
- return &http.Response{
- StatusCode: 200,
- Body: r,
- }, nil
- },
- }
-
- // Create temporary federator worker for transport controller
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- _ = fedWorker.Start()
- defer func() { _ = fedWorker.Stop() }()
-
- // use the client to create a new transport
- c := NewTestTransportController(client, NewTestDB(), fedWorker)
- tp, err := c.NewTransport(pubKeyID, privkey)
+// GetSignatureForDereference prepares a mock HTTP request as if it were going to dereference destination signed for privkey and pubKeyID, signs the request and returns the header values.
+func GetSignatureForDereference(pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) {
+ // Prepare HTTP request signer
+ sig, err := transport.NewGETSigner(120)
if err != nil {
panic(err)
}
- // trigger the dereference function for the underlying signature transport, which will trigger the 'do' function of the recorder above
- if _, err := tp.SigTransport().Dereference(context.Background(), destination); err != nil {
+ // Prepare a mock request ready for signing
+ r, err := http.NewRequest("GET", destination.String(), nil)
+ if err != nil {
panic(err)
}
+ r.Header.Set("Host", destination.Host)
+ r.Header.Set("Date", time.Now().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
+
+ // Sign this new HTTP request
+ if err := sig.SignRequest(privkey, pubKeyID, r, nil); err != nil {
+ panic(err)
+ }
+
+ // Load signed data from request
+ signatureHeader = r.Header.Get("Signature")
+ digestHeader = r.Header.Get("Digest")
+ dateHeader = r.Header.Get("Date")
// headers should now be populated
return
diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go
index 943be7a61..7f4c7f890 100644
--- a/testrig/transportcontroller.go
+++ b/testrig/transportcontroller.go
@@ -24,11 +24,11 @@ import (
"net/http"
"github.com/superseriousbusiness/activity/pub"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// NewTestTransportController returns a test transport controller with the given http client.
@@ -40,7 +40,7 @@ import (
// Unlike the other test interfaces provided in this package, you'll probably want to call this function
// PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular)
// basis.
-func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *worker.Worker[messages.FromFederator]) transport.Controller {
+func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) transport.Controller {
return transport.NewController(db, NewTestFederatingDB(db, fedWorker), &federation.Clock{}, client)
}
diff --git a/vendor/codeberg.org/gruf/go-byteutil/bytes.go b/vendor/codeberg.org/gruf/go-byteutil/bytes.go
index 4fbfe001d..eb57a90de 100644
--- a/vendor/codeberg.org/gruf/go-byteutil/bytes.go
+++ b/vendor/codeberg.org/gruf/go-byteutil/bytes.go
@@ -16,11 +16,30 @@ func Copy(b []byte) []byte {
}
// B2S returns a string representation of []byte without allocation.
+//
+// According to the Go spec strings are immutable and byte slices are not. The way this gets implemented is strings under the hood are:
+// type StringHeader struct {
+// Data uintptr
+// Len int
+// }
+//
+// while slices are:
+// type SliceHeader struct {
+// Data uintptr
+// Len int
+// Cap int
+// }
+// because being mutable, you can change the data, length etc, but the string has to promise to be read-only to all who get copies of it.
+//
+// So in practice when you do a conversion of `string(byteSlice)` it actually performs an allocation because it has to copy the contents of the byte slice into a safe read-only state.
+//
+// Being that the shared fields are in the same struct indices (no different offsets), means that if you have a byte slice you can "forcibly" cast it to a string. Which in a lot of situations can be risky, because then it means you have a string that is NOT immutable, as if someone changes the data in the originating byte slice then the string will reflect that change! Now while this does seem hacky, and it _kind_ of is, it is something that you see performed in the standard library. If you look at the definition for `strings.Builder{}.String()` you'll see this :)
func B2S(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// S2B returns a []byte representation of string without allocation (minus slice header).
+// See B2S() code comment, and this function's implementation for a better understanding.
func S2B(s string) []byte {
var b []byte
diff --git a/vendor/codeberg.org/gruf/go-cache/v2/LICENSE b/vendor/codeberg.org/gruf/go-cache/v2/LICENSE
new file mode 100644
index 000000000..b7c4417ac
--- /dev/null
+++ b/vendor/codeberg.org/gruf/go-cache/v2/LICENSE
@@ -0,0 +1,9 @@
+MIT License
+
+Copyright (c) 2021 gruf
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/vendor/codeberg.org/gruf/go-cache/v2/README.md b/vendor/codeberg.org/gruf/go-cache/v2/README.md
new file mode 100644
index 000000000..69eee7039
--- /dev/null
+++ b/vendor/codeberg.org/gruf/go-cache/v2/README.md
@@ -0,0 +1,3 @@
+# go-cache
+
+A TTL cache designed to be used as a base for your own customizations, or used straight out of the box
\ No newline at end of file
diff --git a/vendor/codeberg.org/gruf/go-cache/v2/cache.go b/vendor/codeberg.org/gruf/go-cache/v2/cache.go
new file mode 100644
index 000000000..89ad314ee
--- /dev/null
+++ b/vendor/codeberg.org/gruf/go-cache/v2/cache.go
@@ -0,0 +1,67 @@
+package cache
+
+import "time"
+
+// Cache represents a TTL cache with customizable callbacks, it
+// exists here to abstract away the "unsafe" methods in the case that
+// you do not want your own implementation atop TTLCache{}.
+type Cache[Key comparable, Value any] interface {
+ // Start will start the cache background eviction routine with given sweep frequency.
+ // If already running or a freq <= 0 provided, this is a no-op. This will block until
+ // the eviction routine has started
+ Start(freq time.Duration) bool
+
+ // Stop will stop cache background eviction routine. If not running this is a no-op. This
+ // will block until the eviction routine has stopped
+ Stop() bool
+
+ // SetEvictionCallback sets the eviction callback to the provided hook
+ SetEvictionCallback(hook Hook[Key, Value])
+
+ // SetInvalidateCallback sets the invalidate callback to the provided hook
+ SetInvalidateCallback(hook Hook[Key, Value])
+
+ // SetTTL sets the cache item TTL. Update can be specified to force updates of existing items in
+ // the cache, this will simply add the change in TTL to their current expiry time
+ SetTTL(ttl time.Duration, update bool)
+
+ // Get fetches the value with key from the cache, extending its TTL
+ Get(key Key) (value Value, ok bool)
+
+ // Put attempts to place the value at key in the cache, doing nothing if
+ // a value with this key already exists. Returned bool is success state
+ Put(key Key, value Value) bool
+
+ // Set places the value at key in the cache. This will overwrite any
+ // existing value, and call the update callback so. Existing values
+ // will have their TTL extended upon update
+ Set(key Key, value Value)
+
+ // CAS will attempt to perform a CAS operation on 'key', using provided
+ // comparison and swap values. Returned bool is success.
+ CAS(key Key, cmp, swp Value) bool
+
+ // Swap will attempt to perform a swap on 'key', replacing the value there
+ // and returning the existing value. If no value exists for key, this will
+ // set the value and return the zero value for V.
+ Swap(key Key, swp Value) Value
+
+ // Has checks the cache for a value with key, this will not update TTL
+ Has(key Key) bool
+
+ // Invalidate deletes a value from the cache, calling the invalidate callback
+ Invalidate(key Key) bool
+
+ // Clear empties the cache, calling the invalidate callback
+ Clear()
+
+ // Size returns the current size of the cache
+ Size() int
+}
+
+// New returns a new initialized Cache.
+func New[K comparable, V any]() Cache[K, V] {
+ c := TTLCache[K, V]{}
+ c.Init()
+ return &c
+}
diff --git a/vendor/codeberg.org/gruf/go-cache/v2/compare.go b/vendor/codeberg.org/gruf/go-cache/v2/compare.go
new file mode 100644
index 000000000..749d6c05f
--- /dev/null
+++ b/vendor/codeberg.org/gruf/go-cache/v2/compare.go
@@ -0,0 +1,23 @@
+package cache
+
+import (
+ "reflect"
+)
+
+type Comparable interface {
+ Equal(any) bool
+}
+
+// Compare returns whether 2 values are equal using the Comparable
+// interface, or failing that falls back to use reflect.DeepEqual().
+func Compare(i1, i2 any) bool {
+ c1, ok1 := i1.(Comparable)
+ if ok1 {
+ return c1.Equal(i2)
+ }
+ c2, ok2 := i2.(Comparable)
+ if ok2 {
+ return c2.Equal(i1)
+ }
+ return reflect.DeepEqual(i1, i2)
+}
diff --git a/vendor/codeberg.org/gruf/go-cache/v2/hook.go b/vendor/codeberg.org/gruf/go-cache/v2/hook.go
new file mode 100644
index 000000000..45ef8c92e
--- /dev/null
+++ b/vendor/codeberg.org/gruf/go-cache/v2/hook.go
@@ -0,0 +1,6 @@
+package cache
+
+// Hook defines a function hook that can be supplied as a callback.
+type Hook[Key comparable, Value any] func(key Key, value Value)
+
+func emptyHook[K comparable, V any](K, V) {}
diff --git a/vendor/codeberg.org/gruf/go-cache/v2/lookup.go b/vendor/codeberg.org/gruf/go-cache/v2/lookup.go
new file mode 100644
index 000000000..cddd1317d
--- /dev/null
+++ b/vendor/codeberg.org/gruf/go-cache/v2/lookup.go
@@ -0,0 +1,214 @@
+package cache
+
+// LookupCfg is the LookupCache configuration.
+type LookupCfg[OGKey, AltKey comparable, Value any] struct {
+ // RegisterLookups is called on init to register lookups
+ // within LookupCache's internal LookupMap
+ RegisterLookups func(*LookupMap[OGKey, AltKey])
+
+ // AddLookups is called on each addition to the cache, to
+ // set any required additional key lookups for supplied item
+ AddLookups func(*LookupMap[OGKey, AltKey], Value)
+
+ // DeleteLookups is called on each eviction/invalidation of
+ // an item in the cache, to remove any unused key lookups
+ DeleteLookups func(*LookupMap[OGKey, AltKey], Value)
+}
+
+// LookupCache is a cache built on-top of TTLCache, providing multi-key
+// lookups for items in the cache by means of additional lookup maps. These
+// maps simply store additional keys => original key, with hook-ins to automatically
+// call user supplied functions on adding an item, or on updating/deleting an
+// item to keep the LookupMap up-to-date.
+type LookupCache[OGKey, AltKey comparable, Value any] interface {
+ Cache[OGKey, Value]
+
+ // GetBy fetches a cached value by supplied lookup identifier and key
+ GetBy(lookup string, key AltKey) (value Value, ok bool)
+
+ // CASBy will attempt to perform a CAS operation on supplied lookup identifier and key
+ CASBy(lookup string, key AltKey, cmp, swp Value) bool
+
+ // SwapBy will attempt to perform a swap operation on supplied lookup identifier and key
+ SwapBy(lookup string, key AltKey, swp Value) Value
+
+ // HasBy checks if a value is cached under supplied lookup identifier and key
+ HasBy(lookup string, key AltKey) bool
+
+ // InvalidateBy invalidates a value by supplied lookup identifier and key
+ InvalidateBy(lookup string, key AltKey) bool
+}
+
+type lookupTTLCache[OK, AK comparable, V any] struct {
+ config LookupCfg[OK, AK, V]
+ lookup LookupMap[OK, AK]
+ TTLCache[OK, V]
+}
+
+// NewLookup returns a new initialized LookupCache.
+func NewLookup[OK, AK comparable, V any](cfg LookupCfg[OK, AK, V]) LookupCache[OK, AK, V] {
+ switch {
+ case cfg.RegisterLookups == nil:
+ panic("cache: nil lookups register function")
+ case cfg.AddLookups == nil:
+ panic("cache: nil lookups add function")
+ case cfg.DeleteLookups == nil:
+ panic("cache: nil delete lookups function")
+ }
+ c := lookupTTLCache[OK, AK, V]{config: cfg}
+ c.TTLCache.Init()
+ c.lookup.lookup = make(map[string]map[AK]OK)
+ c.config.RegisterLookups(&c.lookup)
+ c.SetEvictionCallback(nil)
+ c.SetInvalidateCallback(nil)
+ c.lookup.initd = true
+ return &c
+}
+
+func (c *lookupTTLCache[OK, AK, V]) SetEvictionCallback(hook Hook[OK, V]) {
+ if hook == nil {
+ hook = emptyHook[OK, V]
+ }
+ c.TTLCache.SetEvictionCallback(func(key OK, value V) {
+ hook(key, value)
+ c.config.DeleteLookups(&c.lookup, value)
+ })
+}
+
+func (c *lookupTTLCache[OK, AK, V]) SetInvalidateCallback(hook Hook[OK, V]) {
+ if hook == nil {
+ hook = emptyHook[OK, V]
+ }
+ c.TTLCache.SetInvalidateCallback(func(key OK, value V) {
+ hook(key, value)
+ c.config.DeleteLookups(&c.lookup, value)
+ })
+}
+
+func (c *lookupTTLCache[OK, AK, V]) GetBy(lookup string, key AK) (V, bool) {
+ c.Lock()
+ origKey, ok := c.lookup.Get(lookup, key)
+ if !ok {
+ c.Unlock()
+ var value V
+ return value, false
+ }
+ v, ok := c.GetUnsafe(origKey)
+ c.Unlock()
+ return v, ok
+}
+
+func (c *lookupTTLCache[OK, AK, V]) Put(key OK, value V) bool {
+ c.Lock()
+ put := c.PutUnsafe(key, value)
+ if put {
+ c.config.AddLookups(&c.lookup, value)
+ }
+ c.Unlock()
+ return put
+}
+
+func (c *lookupTTLCache[OK, AK, V]) Set(key OK, value V) {
+ c.Lock()
+ defer c.Unlock()
+ c.SetUnsafe(key, value)
+ c.config.AddLookups(&c.lookup, value)
+}
+
+func (c *lookupTTLCache[OK, AK, V]) CASBy(lookup string, key AK, cmp, swp V) bool {
+ c.Lock()
+ defer c.Unlock()
+ origKey, ok := c.lookup.Get(lookup, key)
+ if !ok {
+ return false
+ }
+ return c.CASUnsafe(origKey, cmp, swp)
+}
+
+func (c *lookupTTLCache[OK, AK, V]) SwapBy(lookup string, key AK, swp V) V {
+ c.Lock()
+ defer c.Unlock()
+ origKey, ok := c.lookup.Get(lookup, key)
+ if !ok {
+ var value V
+ return value
+ }
+ return c.SwapUnsafe(origKey, swp)
+}
+
+func (c *lookupTTLCache[OK, AK, V]) HasBy(lookup string, key AK) bool {
+ c.Lock()
+ has := c.lookup.Has(lookup, key)
+ c.Unlock()
+ return has
+}
+
+func (c *lookupTTLCache[OK, AK, V]) InvalidateBy(lookup string, key AK) bool {
+ c.Lock()
+ defer c.Unlock()
+ origKey, ok := c.lookup.Get(lookup, key)
+ if !ok {
+ return false
+ }
+ c.InvalidateUnsafe(origKey)
+ return true
+}
+
+// LookupMap is a structure that provides lookups for
+// keys to primary keys under supplied lookup identifiers.
+// This is essentially a wrapper around map[string](map[K1]K2).
+type LookupMap[OK comparable, AK comparable] struct {
+ initd bool
+ lookup map[string](map[AK]OK)
+}
+
+// RegisterLookup registers a lookup identifier in the LookupMap,
+// note this can only be doing during the cfg.RegisterLookups() hook.
+func (l *LookupMap[OK, AK]) RegisterLookup(id string) {
+ if l.initd {
+ panic("cache: cannot register lookup after initialization")
+ } else if _, ok := l.lookup[id]; ok {
+ panic("cache: lookup mapping already exists for identifier")
+ }
+ l.lookup[id] = make(map[AK]OK, 100)
+}
+
+// Get fetches an entry's primary key for lookup identifier and key.
+func (l *LookupMap[OK, AK]) Get(id string, key AK) (OK, bool) {
+ keys, ok := l.lookup[id]
+ if !ok {
+ var key OK
+ return key, false
+ }
+ origKey, ok := keys[key]
+ return origKey, ok
+}
+
+// Set adds a lookup to the LookupMap under supplied lookup identifier,
+// linking supplied key to the supplied primary (original) key.
+func (l *LookupMap[OK, AK]) Set(id string, key AK, origKey OK) {
+ keys, ok := l.lookup[id]
+ if !ok {
+ panic("cache: invalid lookup identifier")
+ }
+ keys[key] = origKey
+}
+
+// Has checks if there exists a lookup for supplied identifier and key.
+func (l *LookupMap[OK, AK]) Has(id string, key AK) bool {
+ keys, ok := l.lookup[id]
+ if !ok {
+ return false
+ }
+ _, ok = keys[key]
+ return ok
+}
+
+// Delete removes a lookup from LookupMap with supplied identifier and key.
+func (l *LookupMap[OK, AK]) Delete(id string, key AK) {
+ keys, ok := l.lookup[id]
+ if !ok {
+ return
+ }
+ delete(keys, key)
+}
diff --git a/vendor/codeberg.org/gruf/go-cache/v2/ttl.go b/vendor/codeberg.org/gruf/go-cache/v2/ttl.go
new file mode 100644
index 000000000..42f28b53b
--- /dev/null
+++ b/vendor/codeberg.org/gruf/go-cache/v2/ttl.go
@@ -0,0 +1,333 @@
+package cache
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ "codeberg.org/gruf/go-runners"
+)
+
+// TTLCache is the underlying Cache implementation, providing both the base
+// Cache interface and access to "unsafe" methods so that you may build your
+// customized caches ontop of this structure.
+type TTLCache[Key comparable, Value any] struct {
+ cache map[Key](*entry[Value])
+ evict Hook[Key, Value] // the evict hook is called when an item is evicted from the cache, includes manual delete
+ invalid Hook[Key, Value] // the invalidate hook is called when an item's data in the cache is invalidated
+ ttl time.Duration // ttl is the item TTL
+ svc runners.Service // svc manages running of the cache eviction routine
+ mu sync.Mutex // mu protects TTLCache for concurrent access
+}
+
+// Init performs Cache initialization, this MUST be called.
+func (c *TTLCache[K, V]) Init() {
+ c.cache = make(map[K](*entry[V]), 100)
+ c.evict = emptyHook[K, V]
+ c.invalid = emptyHook[K, V]
+ c.ttl = time.Minute * 5
+}
+
+func (c *TTLCache[K, V]) Start(freq time.Duration) bool {
+ // Nothing to start
+ if freq <= 0 {
+ return false
+ }
+
+ // Track state of starting
+ done := make(chan struct{})
+ started := false
+
+ go func() {
+ ran := c.svc.Run(func(ctx context.Context) {
+ // Successfully started
+ started = true
+ close(done)
+
+ // start routine
+ c.run(ctx, freq)
+ })
+
+ // failed to start
+ if !ran {
+ close(done)
+ }
+ }()
+
+ <-done
+ return started
+}
+
+func (c *TTLCache[K, V]) Stop() bool {
+ return c.svc.Stop()
+}
+
+func (c *TTLCache[K, V]) run(ctx context.Context, freq time.Duration) {
+ t := time.NewTimer(freq)
+ for {
+ select {
+ // we got stopped
+ case <-ctx.Done():
+ if !t.Stop() {
+ <-t.C
+ }
+ return
+
+ // next tick
+ case <-t.C:
+ c.sweep()
+ t.Reset(freq)
+ }
+ }
+}
+
+// sweep attempts to evict expired items (with callback!) from cache.
+func (c *TTLCache[K, V]) sweep() {
+ // Lock and defer unlock (in case of hook panic)
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // Fetch current time for TTL check
+ now := time.Now()
+
+ // Sweep the cache for old items!
+ for key, item := range c.cache {
+ if now.After(item.expiry) {
+ c.evict(key, item.value)
+ delete(c.cache, key)
+ }
+ }
+}
+
+// Lock locks the cache mutex.
+func (c *TTLCache[K, V]) Lock() {
+ c.mu.Lock()
+}
+
+// Unlock unlocks the cache mutex.
+func (c *TTLCache[K, V]) Unlock() {
+ c.mu.Unlock()
+}
+
+func (c *TTLCache[K, V]) SetEvictionCallback(hook Hook[K, V]) {
+ // Ensure non-nil hook
+ if hook == nil {
+ hook = emptyHook[K, V]
+ }
+
+ // Safely set evict hook
+ c.Lock()
+ c.evict = hook
+ c.Unlock()
+}
+
+func (c *TTLCache[K, V]) SetInvalidateCallback(hook Hook[K, V]) {
+ // Ensure non-nil hook
+ if hook == nil {
+ hook = emptyHook[K, V]
+ }
+
+ // Safely set invalidate hook
+ c.Lock()
+ c.invalid = hook
+ c.Unlock()
+}
+
+func (c *TTLCache[K, V]) SetTTL(ttl time.Duration, update bool) {
+ // Safely update TTL
+ c.Lock()
+ diff := ttl - c.ttl
+ c.ttl = ttl
+
+ if update {
+ // Update existing cache entries
+ for _, entry := range c.cache {
+ entry.expiry.Add(diff)
+ }
+ }
+
+ // We're done
+ c.Unlock()
+}
+
+func (c *TTLCache[K, V]) Get(key K) (V, bool) {
+ c.Lock()
+ value, ok := c.GetUnsafe(key)
+ c.Unlock()
+ return value, ok
+}
+
+// GetUnsafe is the mutex-unprotected logic for Cache.Get().
+func (c *TTLCache[K, V]) GetUnsafe(key K) (V, bool) {
+ item, ok := c.cache[key]
+ if !ok {
+ var value V
+ return value, false
+ }
+ item.expiry = time.Now().Add(c.ttl)
+ return item.value, true
+}
+
+func (c *TTLCache[K, V]) Put(key K, value V) bool {
+ c.Lock()
+ success := c.PutUnsafe(key, value)
+ c.Unlock()
+ return success
+}
+
+// PutUnsafe is the mutex-unprotected logic for Cache.Put().
+func (c *TTLCache[K, V]) PutUnsafe(key K, value V) bool {
+ // If already cached, return
+ if _, ok := c.cache[key]; ok {
+ return false
+ }
+
+ // Create new cached item
+ c.cache[key] = &entry[V]{
+ value: value,
+ expiry: time.Now().Add(c.ttl),
+ }
+
+ return true
+}
+
+func (c *TTLCache[K, V]) Set(key K, value V) {
+ c.Lock()
+ defer c.Unlock() // defer in case of hook panic
+ c.SetUnsafe(key, value)
+}
+
+// SetUnsafe is the mutex-unprotected logic for Cache.Set(), it calls externally-set functions.
+func (c *TTLCache[K, V]) SetUnsafe(key K, value V) {
+ item, ok := c.cache[key]
+ if ok {
+ // call invalidate hook
+ c.invalid(key, item.value)
+ } else {
+ // alloc new item
+ item = &entry[V]{}
+ c.cache[key] = item
+ }
+
+ // Update the item + expiry
+ item.value = value
+ item.expiry = time.Now().Add(c.ttl)
+}
+
+func (c *TTLCache[K, V]) CAS(key K, cmp V, swp V) bool {
+ c.Lock()
+ ok := c.CASUnsafe(key, cmp, swp)
+ c.Unlock()
+ return ok
+}
+
+// CASUnsafe is the mutex-unprotected logic for Cache.CAS().
+func (c *TTLCache[K, V]) CASUnsafe(key K, cmp V, swp V) bool {
+ // Check for item
+ item, ok := c.cache[key]
+ if !ok || !Compare(item.value, cmp) {
+ return false
+ }
+
+ // Invalidate item
+ c.invalid(key, item.value)
+
+ // Update item + expiry
+ item.value = swp
+ item.expiry = time.Now().Add(c.ttl)
+
+ return ok
+}
+
+func (c *TTLCache[K, V]) Swap(key K, swp V) V {
+ c.Lock()
+ old := c.SwapUnsafe(key, swp)
+ c.Unlock()
+ return old
+}
+
+// SwapUnsafe is the mutex-unprotected logic for Cache.Swap().
+func (c *TTLCache[K, V]) SwapUnsafe(key K, swp V) V {
+ // Check for item
+ item, ok := c.cache[key]
+ if !ok {
+ var value V
+ return value
+ }
+
+ // invalidate old item
+ c.invalid(key, item.value)
+ old := item.value
+
+ // update item + expiry
+ item.value = swp
+ item.expiry = time.Now().Add(c.ttl)
+
+ return old
+}
+
+func (c *TTLCache[K, V]) Has(key K) bool {
+ c.Lock()
+ ok := c.HasUnsafe(key)
+ c.Unlock()
+ return ok
+}
+
+// HasUnsafe is the mutex-unprotected logic for Cache.Has().
+func (c *TTLCache[K, V]) HasUnsafe(key K) bool {
+ _, ok := c.cache[key]
+ return ok
+}
+
+func (c *TTLCache[K, V]) Invalidate(key K) bool {
+ c.Lock()
+ defer c.Unlock()
+ return c.InvalidateUnsafe(key)
+}
+
+// InvalidateUnsafe is mutex-unprotected logic for Cache.Invalidate().
+func (c *TTLCache[K, V]) InvalidateUnsafe(key K) bool {
+ // Check if we have item with key
+ item, ok := c.cache[key]
+ if !ok {
+ return false
+ }
+
+ // Call hook, remove from cache
+ c.invalid(key, item.value)
+ delete(c.cache, key)
+ return true
+}
+
+func (c *TTLCache[K, V]) Clear() {
+ c.Lock()
+ defer c.Unlock()
+ c.ClearUnsafe()
+}
+
+// ClearUnsafe is mutex-unprotected logic for Cache.Clean().
+func (c *TTLCache[K, V]) ClearUnsafe() {
+ for key, item := range c.cache {
+ c.invalid(key, item.value)
+ delete(c.cache, key)
+ }
+}
+
+func (c *TTLCache[K, V]) Size() int {
+ c.Lock()
+ sz := c.SizeUnsafe()
+ c.Unlock()
+ return sz
+}
+
+// SizeUnsafe is mutex unprotected logic for Cache.Size().
+func (c *TTLCache[K, V]) SizeUnsafe() int {
+ return len(c.cache)
+}
+
+// entry represents an item in the cache, with
+// it's currently calculated expiry time.
+type entry[Value any] struct {
+ value Value
+ expiry time.Time
+}
diff --git a/vendor/modules.txt b/vendor/modules.txt
index 9631a73df..67393f962 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -4,9 +4,12 @@ codeberg.org/gruf/go-bitutil
# codeberg.org/gruf/go-bytes v1.0.2
## explicit; go 1.14
codeberg.org/gruf/go-bytes
-# codeberg.org/gruf/go-byteutil v1.0.0
+# codeberg.org/gruf/go-byteutil v1.0.1
## explicit; go 1.16
codeberg.org/gruf/go-byteutil
+# codeberg.org/gruf/go-cache/v2 v2.0.1
+## explicit; go 1.18
+codeberg.org/gruf/go-cache/v2
# codeberg.org/gruf/go-debug v1.1.2
## explicit; go 1.16
codeberg.org/gruf/go-debug