mirror of
1
Fork 0

[feature] persist worker queues to db (#3042)

* persist queued worker tasks to database on shutdown, fill worker queues from database on startup

* ensure the tasks are sorted by creation time before pushing them

* add migration to insert WorkerTask{} into database, add test for worker task persistence

* add test for recovering worker queues from database

* quick tweak

* whoops we ended up with double cleaner job scheduling

* insert each task separately, because bun is throwing some reflection error??

* add specific checking of cancelled worker contexts

* add http request signing to deliveries recovered from database

* add test for outgoing public key ID being correctly set on delivery

* replace select with Queue.PopCtx()

* get rid of loop now we don't use it

* remove field now we don't use it

* ensure that signing func is set

* header values weren't being copied over 🤦

* use ptr for httpclient.Request in delivery

* move worker queue filling to later in server init process

* fix rebase issues

* make logging less shouty

* use slices.Delete() instead of copying / reslicing

* have database return tasks in ascending order instead of sorting them

* add a 1 minute timeout to persisting worker queues
This commit is contained in:
kim 2024-07-30 11:58:31 +00:00 committed by GitHub
parent 42932f9820
commit 87cff71af9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1191 additions and 93 deletions

View File

@ -87,9 +87,9 @@ var Start action.GTSAction = func(ctx context.Context) error {
// defer function for safe shutdown // defer function for safe shutdown
// depending on what services were // depending on what services were
// managed to be started. // managed to be started.
state = new(state.State)
state = new(state.State) route *router.Router
route *router.Router process *processing.Processor
) )
defer func() { defer func() {
@ -125,6 +125,23 @@ var Start action.GTSAction = func(ctx context.Context) error {
} }
} }
if process != nil {
const timeout = time.Minute
// Use a new timeout context to ensure
// persisting queued tasks does not fail!
// The main ctx is very likely canceled.
ctx := context.WithoutCancel(ctx)
ctx, cncl := context.WithTimeout(ctx, timeout)
defer cncl()
// Now that all the "moving" components have been stopped,
// persist any remaining queued worker tasks to the database.
if err := process.Admin().PersistWorkerQueues(ctx); err != nil {
log.Errorf(ctx, "error persisting worker queues: %v", err)
}
}
if state.DB != nil { if state.DB != nil {
// Lastly, if database service was started, // Lastly, if database service was started,
// ensure it gets closed now all else stopped. // ensure it gets closed now all else stopped.
@ -270,7 +287,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
// Create the processor using all the // Create the processor using all the
// other services we've created so far. // other services we've created so far.
processor := processing.NewProcessor( process = processing.NewProcessor(
cleaner, cleaner,
typeConverter, typeConverter,
federator, federator,
@ -286,14 +303,14 @@ var Start action.GTSAction = func(ctx context.Context) error {
state.Workers.Client.Init(messages.ClientMsgIndices()) state.Workers.Client.Init(messages.ClientMsgIndices())
state.Workers.Federator.Init(messages.FederatorMsgIndices()) state.Workers.Federator.Init(messages.FederatorMsgIndices())
state.Workers.Delivery.Init(client) state.Workers.Delivery.Init(client)
state.Workers.Client.Process = processor.Workers().ProcessFromClientAPI state.Workers.Client.Process = process.Workers().ProcessFromClientAPI
state.Workers.Federator.Process = processor.Workers().ProcessFromFediAPI state.Workers.Federator.Process = process.Workers().ProcessFromFediAPI
// Now start workers! // Now start workers!
state.Workers.Start() state.Workers.Start()
// Schedule notif tasks for all existing poll expiries. // Schedule notif tasks for all existing poll expiries.
if err := processor.Polls().ScheduleAll(ctx); err != nil { if err := process.Polls().ScheduleAll(ctx); err != nil {
return fmt.Errorf("error scheduling poll expiries: %w", err) return fmt.Errorf("error scheduling poll expiries: %w", err)
} }
@ -303,7 +320,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
} }
// Run advanced migrations. // Run advanced migrations.
if err := processor.AdvancedMigrations().Migrate(ctx); err != nil { if err := process.AdvancedMigrations().Migrate(ctx); err != nil {
return err return err
} }
@ -370,7 +387,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
// attach global no route / 404 handler to the router // attach global no route / 404 handler to the router
route.AttachNoRouteHandler(func(c *gin.Context) { route.AttachNoRouteHandler(func(c *gin.Context) {
apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), processor.InstanceGetV1) apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), process.InstanceGetV1)
}) })
// build router modules // build router modules
@ -393,15 +410,15 @@ var Start action.GTSAction = func(ctx context.Context) error {
} }
var ( var (
authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths authModule = api.NewAuth(dbService, process, idp, routerSession, sessionName) // auth/oauth paths
clientModule = api.NewClient(state, processor) // api client endpoints clientModule = api.NewClient(state, process) // api client endpoints
metricsModule = api.NewMetrics() // Metrics endpoints metricsModule = api.NewMetrics() // Metrics endpoints
healthModule = api.NewHealth(dbService.Ready) // Health check endpoints healthModule = api.NewHealth(dbService.Ready) // Health check endpoints
fileserverModule = api.NewFileserver(processor) // fileserver endpoints fileserverModule = api.NewFileserver(process) // fileserver endpoints
wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints wellKnownModule = api.NewWellKnown(process) // .well-known endpoints
nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint nodeInfoModule = api.NewNodeInfo(process) // nodeinfo endpoint
activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints activityPubModule = api.NewActivityPub(dbService, process) // ActivityPub endpoints
webModule = web.New(dbService, processor) // web pages + user profiles + settings panels etc webModule = web.New(dbService, process) // web pages + user profiles + settings panels etc
) )
// create required middleware // create required middleware
@ -416,10 +433,11 @@ var Start action.GTSAction = func(ctx context.Context) error {
// throttling // throttling
cpuMultiplier := config.GetAdvancedThrottlingMultiplier() cpuMultiplier := config.GetAdvancedThrottlingMultiplier()
retryAfter := config.GetAdvancedThrottlingRetryAfter() retryAfter := config.GetAdvancedThrottlingRetryAfter()
clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api
s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // server-to-server (AP) s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter)
fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis // server-to-server (AP)
pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis
pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately
gzip := middleware.Gzip() // applied to all except fileserver gzip := middleware.Gzip() // applied to all except fileserver
@ -442,6 +460,11 @@ var Start action.GTSAction = func(ctx context.Context) error {
return fmt.Errorf("error starting router: %w", err) return fmt.Errorf("error starting router: %w", err)
} }
// Fill worker queues from persisted task data in database.
if err := process.Admin().FillWorkerQueues(ctx); err != nil {
return fmt.Errorf("error filling worker queues: %w", err)
}
// catch shutdown signals from the operating system // catch shutdown signals from the operating system
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

View File

@ -84,6 +84,7 @@ type DBService struct {
db.Timeline db.Timeline
db.User db.User
db.Tombstone db.Tombstone
db.WorkerTask
db *bun.DB db *bun.DB
} }
@ -302,6 +303,9 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db, db: db,
state: state, state: state,
}, },
WorkerTask: &workerTaskDB{
db: db,
},
db: db, db: db,
} }

View File

@ -0,0 +1,51 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// WorkerTask table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.WorkerTask{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,58 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type workerTaskDB struct{ db *bun.DB }
func (w *workerTaskDB) GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error) {
var tasks []*gtsmodel.WorkerTask
if err := w.db.NewSelect().
Model(&tasks).
OrderExpr("? ASC", bun.Ident("created_at")).
Scan(ctx); err != nil {
return nil, err
}
return tasks, nil
}
func (w *workerTaskDB) PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error {
var errs []error
for _, task := range tasks {
_, err := w.db.NewInsert().Model(task).Exec(ctx)
if err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
func (w *workerTaskDB) DeleteWorkerTaskByID(ctx context.Context, id uint) error {
_, err := w.db.NewDelete().
Table("worker_tasks").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return err
}

View File

@ -56,4 +56,5 @@ type DB interface {
Timeline Timeline
User User
Tombstone Tombstone
WorkerTask
} }

35
internal/db/workertask.go Normal file
View File

@ -0,0 +1,35 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type WorkerTask interface {
// GetWorkerTasks fetches all persisted worker tasks from the database.
GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error)
// PutWorkerTasks persists the given worker tasks to the database.
PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error
// DeleteWorkerTask deletes worker task with given ID from database.
DeleteWorkerTaskByID(ctx context.Context, id uint) error
}

View File

@ -34,8 +34,8 @@ const (
// queued tasks from being lost. It is simply a // queued tasks from being lost. It is simply a
// means to store a blob of serialized task data. // means to store a blob of serialized task data.
type WorkerTask struct { type WorkerTask struct {
ID uint `bun:""` ID uint `bun:",pk,autoincrement"`
WorkerType uint8 `bun:""` WorkerType WorkerType `bun:",notnull"`
TaskData []byte `bun:""` TaskData []byte `bun:",nullzero,notnull"`
CreatedAt time.Time `bun:""` CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"`
} }

View File

@ -197,7 +197,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {
// If the fast-fail flag was set, just // If the fast-fail flag was set, just
// attempt a single iteration instead of // attempt a single iteration instead of
// following the below retry-backoff loop. // following the below retry-backoff loop.
rsp, _, err = c.DoOnce(&req) rsp, _, err = c.DoOnce(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w (fast fail)", err) return nil, fmt.Errorf("%w (fast fail)", err)
} }
@ -208,7 +208,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {
var retry bool var retry bool
// Perform the http request. // Perform the http request.
rsp, retry, err = c.DoOnce(&req) rsp, retry, err = c.DoOnce(req)
if err == nil { if err == nil {
return rsp, nil return rsp, nil
} }

View File

@ -47,8 +47,8 @@ type Request struct {
// WrapRequest wraps an existing http.Request within // WrapRequest wraps an existing http.Request within
// our own httpclient.Request with retry / backoff tracking. // our own httpclient.Request with retry / backoff tracking.
func WrapRequest(r *http.Request) Request { func WrapRequest(r *http.Request) *Request {
var rr Request rr := new(Request)
rr.Request = r rr.Request = r
entry := log.WithContext(r.Context()) entry := log.WithContext(r.Context())
entry = entry.WithField("method", r.Method) entry = entry.WithField("method", r.Method)

View File

@ -352,7 +352,7 @@ func resolveAPObject(data map[string]interface{}) (interface{}, error) {
// we then need to wrangle back into the original type. So we also store the type name // we then need to wrangle back into the original type. So we also store the type name
// and use this to determine the appropriate Go structure type to unmarshal into to. // and use this to determine the appropriate Go structure type to unmarshal into to.
func resolveGTSModel(typ string, data []byte) (interface{}, error) { func resolveGTSModel(typ string, data []byte) (interface{}, error) {
if typ == "" && data == nil { if typ == "" {
// No data given. // No data given.
return nil, nil return nil, nil
} }

View File

@ -0,0 +1,426 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import (
"context"
"fmt"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
)
// NOTE:
// Having these functions in the processor, which is
// usually the intermediary that performs *processing*
// between the HTTP route handlers and the underlying
// database / storage layers is a little odd, so this
// may be subject to change!
//
// For now at least, this is a useful place that has
// access to the underlying database, workers and
// causes no dependency cycles with this use case!
// FillWorkerQueues recovers all serialized worker tasks from the database
// (if any!), and pushes them to each of their relevant worker queues.
func (p *Processor) FillWorkerQueues(ctx context.Context) error {
log.Info(ctx, "rehydrate!")
// Get all persisted worker tasks from db.
//
// (database returns these as ASCENDING, i.e.
// returned in the order they were inserted).
tasks, err := p.state.DB.GetWorkerTasks(ctx)
if err != nil {
return gtserror.Newf("error fetching worker tasks from db: %w", err)
}
var (
// Counts of each task type
// successfully recovered.
delivery int
federator int
client int
// Failed recoveries.
errors int
)
loop:
// Handle each persisted task, removing
// all those we can't handle. Leaving us
// with a slice of tasks we can safely
// delete from being persisted in the DB.
for i := 0; i < len(tasks); {
var err error
// Task at index.
task := tasks[i]
// Appropriate task count
// pointer to increment.
var counter *int
// Attempt to recovery persisted
// task depending on worker type.
switch task.WorkerType {
case gtsmodel.DeliveryWorker:
err = p.pushDelivery(ctx, task)
counter = &delivery
case gtsmodel.FederatorWorker:
err = p.pushFederator(ctx, task)
counter = &federator
case gtsmodel.ClientWorker:
err = p.pushClient(ctx, task)
counter = &client
default:
err = fmt.Errorf("invalid worker type %d", task.WorkerType)
}
if err != nil {
log.Errorf(ctx, "error pushing task %d: %v", task.ID, err)
// Drop error'd task from slice.
tasks = slices.Delete(tasks, i, i+1)
// Incr errors.
errors++
continue loop
}
// Increment slice
// index & counter.
(*counter)++
i++
}
// Tasks that worker successfully pushed
// to their appropriate workers, we can
// safely now remove from the database.
for _, task := range tasks {
if err := p.state.DB.DeleteWorkerTaskByID(ctx, task.ID); err != nil {
log.Errorf(ctx, "error deleting task from db: %v", err)
}
}
// Log recovered tasks.
log.WithContext(ctx).
WithField("delivery", delivery).
WithField("federator", federator).
WithField("client", client).
WithField("errors", errors).
Info("recovered queued tasks")
return nil
}
// PersistWorkerQueues pops all queued worker tasks (that are themselves persistable, i.e. not
// dereference tasks which are just function ptrs), serializes and persists them to the database.
func (p *Processor) PersistWorkerQueues(ctx context.Context) error {
log.Info(ctx, "dehydrate!")
var (
// Counts of each task type
// successfully persisted.
delivery int
federator int
client int
// Failed persists.
errors int
// Serialized tasks to persist.
tasks []*gtsmodel.WorkerTask
)
for {
// Pop all queued deliveries.
task, err := p.popDelivery()
if err != nil {
log.Errorf(ctx, "error popping delivery: %v", err)
errors++ // incr error count.
continue
}
if task == nil {
// No more queue
// tasks to pop!
break
}
// Append serialized task.
tasks = append(tasks, task)
delivery++ // incr count
}
for {
// Pop queued federator msgs.
task, err := p.popFederator()
if err != nil {
log.Errorf(ctx, "error popping federator message: %v", err)
errors++ // incr count
continue
}
if task == nil {
// No more queue
// tasks to pop!
break
}
// Append serialized task.
tasks = append(tasks, task)
federator++ // incr count
}
for {
// Pop queued client msgs.
task, err := p.popClient()
if err != nil {
log.Errorf(ctx, "error popping client message: %v", err)
continue
}
if task == nil {
// No more queue
// tasks to pop!
break
}
// Append serialized task.
tasks = append(tasks, task)
client++ // incr count
}
// Persist all serialized queued worker tasks to database.
if err := p.state.DB.PutWorkerTasks(ctx, tasks); err != nil {
return gtserror.Newf("error putting tasks in db: %w", err)
}
// Log recovered tasks.
log.WithContext(ctx).
WithField("delivery", delivery).
WithField("federator", federator).
WithField("client", client).
WithField("errors", errors).
Info("persisted queued tasks")
return nil
}
// pushDelivery parses a valid delivery.Delivery{} from serialized task data and pushes to queue.
func (p *Processor) pushDelivery(ctx context.Context, task *gtsmodel.WorkerTask) error {
dlv := new(delivery.Delivery)
// Deserialize the raw worker task data into delivery.
if err := dlv.Deserialize(task.TaskData); err != nil {
return gtserror.Newf("error deserializing delivery: %w", err)
}
var tsport transport.Transport
if uri := dlv.ActorID; uri != "" {
// Fetch the actor account by provided URI from db.
account, err := p.state.DB.GetAccountByURI(ctx, uri)
if err != nil {
return gtserror.Newf("error getting actor account %s from db: %w", uri, err)
}
// Fetch a transport for request signing for actor's account username.
tsport, err = p.transport.NewTransportForUsername(ctx, account.Username)
if err != nil {
return gtserror.Newf("error getting transport for actor %s: %w", uri, err)
}
} else {
var err error
// No actor was given, will be signed by instance account.
tsport, err = p.transport.NewTransportForUsername(ctx, "")
if err != nil {
return gtserror.Newf("error getting instance account transport: %w", err)
}
}
// Using transport, add actor signature to delivery.
if err := tsport.SignDelivery(dlv); err != nil {
return gtserror.Newf("error signing delivery: %w", err)
}
// Push deserialized task to delivery queue.
p.state.Workers.Delivery.Queue.Push(dlv)
return nil
}
// popDelivery pops delivery.Delivery{} from queue and serializes as valid task data.
func (p *Processor) popDelivery() (*gtsmodel.WorkerTask, error) {
// Pop waiting delivery from the delivery worker.
delivery, ok := p.state.Workers.Delivery.Queue.Pop()
if !ok {
return nil, nil
}
// Serialize the delivery task data.
data, err := delivery.Serialize()
if err != nil {
return nil, gtserror.Newf("error serializing delivery: %w", err)
}
return &gtsmodel.WorkerTask{
// ID is autoincrement
WorkerType: gtsmodel.DeliveryWorker,
TaskData: data,
CreatedAt: time.Now(),
}, nil
}
// pushClient parses a valid messages.FromFediAPI{} from serialized task data and pushes to queue.
func (p *Processor) pushFederator(ctx context.Context, task *gtsmodel.WorkerTask) error {
var msg messages.FromFediAPI
// Deserialize the raw worker task data into message.
if err := msg.Deserialize(task.TaskData); err != nil {
return gtserror.Newf("error deserializing federator message: %w", err)
}
if rcv := msg.Receiving; rcv != nil {
// Only a placeholder receiving account will be populated,
// fetch the actual model from database by persisted ID.
account, err := p.state.DB.GetAccountByID(ctx, rcv.ID)
if err != nil {
return gtserror.Newf("error fetching receiving account %s from db: %w", rcv.ID, err)
}
// Set the now populated
// receiving account model.
msg.Receiving = account
}
if req := msg.Requesting; req != nil {
// Only a placeholder requesting account will be populated,
// fetch the actual model from database by persisted ID.
account, err := p.state.DB.GetAccountByID(ctx, req.ID)
if err != nil {
return gtserror.Newf("error fetching requesting account %s from db: %w", req.ID, err)
}
// Set the now populated
// requesting account model.
msg.Requesting = account
}
// Push populated task to the federator queue.
p.state.Workers.Federator.Queue.Push(&msg)
return nil
}
// popFederator pops messages.FromFediAPI{} from queue and serializes as valid task data.
func (p *Processor) popFederator() (*gtsmodel.WorkerTask, error) {
// Pop waiting message from the federator worker.
msg, ok := p.state.Workers.Federator.Queue.Pop()
if !ok {
return nil, nil
}
// Serialize message task data.
data, err := msg.Serialize()
if err != nil {
return nil, gtserror.Newf("error serializing federator message: %w", err)
}
return &gtsmodel.WorkerTask{
// ID is autoincrement
WorkerType: gtsmodel.FederatorWorker,
TaskData: data,
CreatedAt: time.Now(),
}, nil
}
// pushClient parses a valid messages.FromClientAPI{} from serialized task data and pushes to queue.
func (p *Processor) pushClient(ctx context.Context, task *gtsmodel.WorkerTask) error {
var msg messages.FromClientAPI
// Deserialize the raw worker task data into message.
if err := msg.Deserialize(task.TaskData); err != nil {
return gtserror.Newf("error deserializing client message: %w", err)
}
if org := msg.Origin; org != nil {
// Only a placeholder origin account will be populated,
// fetch the actual model from database by persisted ID.
account, err := p.state.DB.GetAccountByID(ctx, org.ID)
if err != nil {
return gtserror.Newf("error fetching origin account %s from db: %w", org.ID, err)
}
// Set the now populated
// origin account model.
msg.Origin = account
}
if trg := msg.Target; trg != nil {
// Only a placeholder target account will be populated,
// fetch the actual model from database by persisted ID.
account, err := p.state.DB.GetAccountByID(ctx, trg.ID)
if err != nil {
return gtserror.Newf("error fetching target account %s from db: %w", trg.ID, err)
}
// Set the now populated
// target account model.
msg.Target = account
}
// Push populated task to the federator queue.
p.state.Workers.Client.Queue.Push(&msg)
return nil
}
// popClient pops messages.FromClientAPI{} from queue and serializes as valid task data.
func (p *Processor) popClient() (*gtsmodel.WorkerTask, error) {
// Pop waiting message from the client worker.
msg, ok := p.state.Workers.Client.Queue.Pop()
if !ok {
return nil, nil
}
// Serialize message task data.
data, err := msg.Serialize()
if err != nil {
return nil, gtserror.Newf("error serializing client message: %w", err)
}
return &gtsmodel.WorkerTask{
// ID is autoincrement
WorkerType: gtsmodel.ClientWorker,
TaskData: data,
CreatedAt: time.Now(),
}, nil
}

View File

@ -0,0 +1,421 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
"github.com/superseriousbusiness/gotosocial/testrig"
)
var (
// TODO: move these test values into
// the testrig test models area. They'll
// need to be as both WorkerTask and as
// the raw types themselves.
testDeliveries = []*delivery.Delivery{
{
ObjectID: "https://google.com/users/bigboy/follow/1",
TargetID: "https://askjeeves.com/users/smallboy",
Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Host": {"https://askjeeves.com"}}),
},
{
Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), http.Header{"Host": {"https://google.com"}}),
},
}
testFederatorMsgs = []*messages.FromFediAPI{
{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
TargetURI: "https://gotosocial.org",
Requesting: &gtsmodel.Account{ID: "654321"},
Receiving: &gtsmodel.Account{ID: "123456"},
},
{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate,
TargetURI: "https://uk-queen-is-dead.org",
Requesting: &gtsmodel.Account{ID: "123456"},
Receiving: &gtsmodel.Account{ID: "654321"},
},
}
testClientMsgs = []*messages.FromClientAPI{
{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
TargetURI: "https://gotosocial.org",
Origin: &gtsmodel.Account{ID: "654321"},
Target: &gtsmodel.Account{ID: "123456"},
},
{
APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate,
TargetURI: "https://uk-queen-is-dead.org",
Origin: &gtsmodel.Account{ID: "123456"},
Target: &gtsmodel.Account{ID: "654321"},
},
}
)
type WorkerTaskTestSuite struct {
AdminStandardTestSuite
}
func (suite *WorkerTaskTestSuite) TestFillWorkerQueues() {
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
var tasks []*gtsmodel.WorkerTask
for _, dlv := range testDeliveries {
// Serialize all test deliveries.
data, err := dlv.Serialize()
if err != nil {
panic(err)
}
// Append each serialized delivery to tasks.
tasks = append(tasks, &gtsmodel.WorkerTask{
WorkerType: gtsmodel.DeliveryWorker,
TaskData: data,
})
}
for _, msg := range testFederatorMsgs {
// Serialize all test messages.
data, err := msg.Serialize()
if err != nil {
panic(err)
}
if msg.Receiving != nil {
// Quick hack to bypass database errors for non-existing
// accounts, instead we just insert this into cache ;).
suite.state.Caches.DB.Account.Put(msg.Receiving)
suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
AccountID: msg.Receiving.ID,
})
}
if msg.Requesting != nil {
// Quick hack to bypass database errors for non-existing
// accounts, instead we just insert this into cache ;).
suite.state.Caches.DB.Account.Put(msg.Requesting)
suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
AccountID: msg.Requesting.ID,
})
}
// Append each serialized message to tasks.
tasks = append(tasks, &gtsmodel.WorkerTask{
WorkerType: gtsmodel.FederatorWorker,
TaskData: data,
})
}
for _, msg := range testClientMsgs {
// Serialize all test messages.
data, err := msg.Serialize()
if err != nil {
panic(err)
}
if msg.Origin != nil {
// Quick hack to bypass database errors for non-existing
// accounts, instead we just insert this into cache ;).
suite.state.Caches.DB.Account.Put(msg.Origin)
suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
AccountID: msg.Origin.ID,
})
}
if msg.Target != nil {
// Quick hack to bypass database errors for non-existing
// accounts, instead we just insert this into cache ;).
suite.state.Caches.DB.Account.Put(msg.Target)
suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
AccountID: msg.Target.ID,
})
}
// Append each serialized message to tasks.
tasks = append(tasks, &gtsmodel.WorkerTask{
WorkerType: gtsmodel.ClientWorker,
TaskData: data,
})
}
// Persist all test worker tasks to the database.
err := suite.state.DB.PutWorkerTasks(ctx, tasks)
suite.NoError(err)
// Fill the worker queues from persisted task data.
err = suite.adminProcessor.FillWorkerQueues(ctx)
suite.NoError(err)
var (
// Recovered
// task counts.
ndelivery int
nfederator int
nclient int
)
// Fetch current gotosocial instance account, for later checks.
instanceAcc, err := suite.state.DB.GetInstanceAccount(ctx, "")
suite.NoError(err)
for {
// Pop all queued delivery tasks from worker queue.
dlv, ok := suite.state.Workers.Delivery.Queue.Pop()
if !ok {
break
}
// Incr count.
ndelivery++
// Check that we have this message in slice.
err = containsSerializable(testDeliveries, dlv)
suite.NoError(err)
// Check that delivery request context has instance account pubkey.
pubKeyID := gtscontext.OutgoingPublicKeyID(dlv.Request.Context())
suite.Equal(instanceAcc.PublicKeyURI, pubKeyID)
signfn := gtscontext.HTTPClientSignFunc(dlv.Request.Context())
suite.NotNil(signfn)
}
for {
// Pop all queued federator messages from worker queue.
msg, ok := suite.state.Workers.Federator.Queue.Pop()
if !ok {
break
}
// Incr count.
nfederator++
// Check that we have this message in slice.
err = containsSerializable(testFederatorMsgs, msg)
suite.NoError(err)
}
for {
// Pop all queued client messages from worker queue.
msg, ok := suite.state.Workers.Client.Queue.Pop()
if !ok {
break
}
// Incr count.
nclient++
// Check that we have this message in slice.
err = containsSerializable(testClientMsgs, msg)
suite.NoError(err)
}
// Ensure recovered task counts as expected.
suite.Equal(len(testDeliveries), ndelivery)
suite.Equal(len(testFederatorMsgs), nfederator)
suite.Equal(len(testClientMsgs), nclient)
}
func (suite *WorkerTaskTestSuite) TestPersistWorkerQueues() {
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Push all test worker tasks to their respective queues.
suite.state.Workers.Delivery.Queue.Push(testDeliveries...)
suite.state.Workers.Federator.Queue.Push(testFederatorMsgs...)
suite.state.Workers.Client.Queue.Push(testClientMsgs...)
// Persist the worker queued tasks to database.
err := suite.adminProcessor.PersistWorkerQueues(ctx)
suite.NoError(err)
// Fetch all the persisted tasks from database.
tasks, err := suite.state.DB.GetWorkerTasks(ctx)
suite.NoError(err)
var (
// Persisted
// task counts.
ndelivery int
nfederator int
nclient int
)
// Check persisted task data.
for _, task := range tasks {
switch task.WorkerType {
case gtsmodel.DeliveryWorker:
var dlv delivery.Delivery
// Incr count.
ndelivery++
// Deserialize the persisted task data.
err := dlv.Deserialize(task.TaskData)
suite.NoError(err)
// Check that we have this delivery in slice.
err = containsSerializable(testDeliveries, &dlv)
suite.NoError(err)
case gtsmodel.FederatorWorker:
var msg messages.FromFediAPI
// Incr count.
nfederator++
// Deserialize the persisted task data.
err := msg.Deserialize(task.TaskData)
suite.NoError(err)
// Check that we have this message in slice.
err = containsSerializable(testFederatorMsgs, &msg)
suite.NoError(err)
case gtsmodel.ClientWorker:
var msg messages.FromClientAPI
// Incr count.
nclient++
// Deserialize the persisted task data.
err := msg.Deserialize(task.TaskData)
suite.NoError(err)
// Check that we have this message in slice.
err = containsSerializable(testClientMsgs, &msg)
suite.NoError(err)
default:
suite.T().Errorf("unexpected worker type: %d", task.WorkerType)
}
}
// Ensure persisted task counts as expected.
suite.Equal(len(testDeliveries), ndelivery)
suite.Equal(len(testFederatorMsgs), nfederator)
suite.Equal(len(testClientMsgs), nclient)
}
func (suite *WorkerTaskTestSuite) SetupTest() {
suite.AdminStandardTestSuite.SetupTest()
// we don't want workers running
testrig.StopWorkers(&suite.state)
}
func TestWorkerTaskTestSuite(t *testing.T) {
suite.Run(t, new(WorkerTaskTestSuite))
}
// containsSerializeable returns whether slice of serializables contains given serializable entry.
func containsSerializable[T interface{ Serialize() ([]byte, error) }](expect []T, have T) error {
// Serialize wanted value.
bh, err := have.Serialize()
if err != nil {
panic(err)
}
var strings []string
for _, t := range expect {
// Serialize expected value.
be, err := t.Serialize()
if err != nil {
panic(err)
}
// Alloc as string.
se := string(be)
if se == string(bh) {
// We have this entry!
return nil
}
// Add to serialized strings.
strings = append(strings, se)
}
return fmt.Errorf("could not find %s in %s", string(bh), strings)
}
// urlStr simply returns u.String() or "" if nil.
func urlStr(u *url.URL) string {
if u == nil {
return ""
}
return u.String()
}
// accountID simply returns account.ID or "" if nil.
func accountID(account *gtsmodel.Account) string {
if account == nil {
return ""
}
return account.ID
}
// toRequest creates httpclient.Request from HTTP method, URL and body data.
func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {
var rbody io.Reader
if body != nil {
rbody = bytes.NewReader(body)
}
req, err := http.NewRequest(method, url, rbody)
if err != nil {
panic(err)
}
for key, values := range hdr {
for _, value := range values {
req.Header.Add(key, value)
}
}
return httpclient.WrapRequest(req)
}
// toJSON marshals input type as JSON data.
func toJSON(a any) []byte {
b, err := json.Marshal(a)
if err != nil {
panic(err)
}
return b
}

View File

@ -21,6 +21,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"net/url" "net/url"
@ -169,6 +170,38 @@ func (t *transport) prepare(
}, nil }, nil
} }
func (t *transport) SignDelivery(dlv *delivery.Delivery) error {
if dlv.Request.GetBody == nil {
return gtserror.New("delivery request body not rewindable")
}
// Get a new copy of the request body.
body, err := dlv.Request.GetBody()
if err != nil {
return gtserror.Newf("error getting request body: %w", err)
}
// Read body data into memory.
data, err := io.ReadAll(body)
if err != nil {
return gtserror.Newf("error reading request body: %w", err)
}
// Get signing function for POST data.
// (note that delivery is ALWAYS POST).
sign := t.signPOST(data)
// Extract delivery context.
ctx := dlv.Request.Context()
// Update delivery request context with signing details.
ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID)
ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign)
dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
return nil
}
// getObjectID extracts an object ID from 'serialized' ActivityPub object map. // getObjectID extracts an object ID from 'serialized' ActivityPub object map.
func getObjectID(obj map[string]interface{}) string { func getObjectID(obj map[string]interface{}) string {
switch t := obj["object"].(type) { switch t := obj["object"].(type) {

View File

@ -33,10 +33,6 @@ import (
// be indexed (and so, dropped from queue) // be indexed (and so, dropped from queue)
// by any of these possible ID IRIs. // by any of these possible ID IRIs.
type Delivery struct { type Delivery struct {
// PubKeyID is the signing public key
// ID of the actor performing request.
PubKeyID string
// ActorID contains the ActivityPub // ActorID contains the ActivityPub
// actor ID IRI (if any) of the activity // actor ID IRI (if any) of the activity
// being sent out by this request. // being sent out by this request.
@ -55,7 +51,7 @@ type Delivery struct {
// Request is the prepared (+ wrapped) // Request is the prepared (+ wrapped)
// httpclient.Client{} request that // httpclient.Client{} request that
// constitutes this ActivtyPub delivery. // constitutes this ActivtyPub delivery.
Request httpclient.Request Request *httpclient.Request
// internal fields. // internal fields.
next time.Time next time.Time
@ -66,7 +62,6 @@ type Delivery struct {
// a json serialize / deserialize // a json serialize / deserialize
// able shape that minimizes data. // able shape that minimizes data.
type delivery struct { type delivery struct {
PubKeyID string `json:"pub_key_id,omitempty"`
ActorID string `json:"actor_id,omitempty"` ActorID string `json:"actor_id,omitempty"`
ObjectID string `json:"object_id,omitempty"` ObjectID string `json:"object_id,omitempty"`
TargetID string `json:"target_id,omitempty"` TargetID string `json:"target_id,omitempty"`
@ -101,7 +96,6 @@ func (dlv *Delivery) Serialize() ([]byte, error) {
// Marshal as internal JSON type. // Marshal as internal JSON type.
return json.Marshal(delivery{ return json.Marshal(delivery{
PubKeyID: dlv.PubKeyID,
ActorID: dlv.ActorID, ActorID: dlv.ActorID,
ObjectID: dlv.ObjectID, ObjectID: dlv.ObjectID,
TargetID: dlv.TargetID, TargetID: dlv.TargetID,
@ -125,7 +119,6 @@ func (dlv *Delivery) Deserialize(data []byte) error {
} }
// Copy over simplest fields. // Copy over simplest fields.
dlv.PubKeyID = idlv.PubKeyID
dlv.ActorID = idlv.ActorID dlv.ActorID = idlv.ActorID
dlv.ObjectID = idlv.ObjectID dlv.ObjectID = idlv.ObjectID
dlv.TargetID = idlv.TargetID dlv.TargetID = idlv.TargetID
@ -143,6 +136,13 @@ func (dlv *Delivery) Deserialize(data []byte) error {
return err return err
} }
// Copy over any stored header values.
for key, values := range idlv.Header {
for _, value := range values {
r.Header.Add(key, value)
}
}
// Wrap request in httpclient type. // Wrap request in httpclient type.
dlv.Request = httpclient.WrapRequest(r) dlv.Request = httpclient.WrapRequest(r)

View File

@ -35,32 +35,30 @@ var deliveryCases = []struct {
}{ }{
{ {
msg: delivery.Delivery{ msg: delivery.Delivery{
PubKeyID: "https://google.com/users/bigboy#pubkey",
ActorID: "https://google.com/users/bigboy", ActorID: "https://google.com/users/bigboy",
ObjectID: "https://google.com/users/bigboy/follow/1", ObjectID: "https://google.com/users/bigboy/follow/1",
TargetID: "https://askjeeves.com/users/smallboy", TargetID: "https://askjeeves.com/users/smallboy",
Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")), Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Hello": {"world1", "world2"}}),
}, },
data: toJSON(map[string]any{ data: toJSON(map[string]any{
"pub_key_id": "https://google.com/users/bigboy#pubkey", "actor_id": "https://google.com/users/bigboy",
"actor_id": "https://google.com/users/bigboy", "object_id": "https://google.com/users/bigboy/follow/1",
"object_id": "https://google.com/users/bigboy/follow/1", "target_id": "https://askjeeves.com/users/smallboy",
"target_id": "https://askjeeves.com/users/smallboy", "method": "POST",
"method": "POST", "url": "https://askjeeves.com/users/smallboy/inbox",
"url": "https://askjeeves.com/users/smallboy/inbox", "body": []byte("data!"),
"body": []byte("data!"), "header": map[string][]string{"Hello": {"world1", "world2"}},
// "header": map[string][]string{},
}), }),
}, },
{ {
msg: delivery.Delivery{ msg: delivery.Delivery{
Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")), Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), nil),
}, },
data: toJSON(map[string]any{ data: toJSON(map[string]any{
"method": "GET", "method": "GET",
"url": "https://google.com", "url": "https://google.com",
"body": []byte("uwu im just a wittle seawch engwin"), "body": []byte("uwu im just a wittle seawch engwin"),
// "header": map[string][]string{}, // "header": map[string][]string{},
}), }),
}, },
} }
@ -89,18 +87,18 @@ func TestDeserializeDelivery(t *testing.T) {
} }
// Check that delivery fields are as expected. // Check that delivery fields are as expected.
assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID)
assert.Equal(t, test.msg.ActorID, msg.ActorID) assert.Equal(t, test.msg.ActorID, msg.ActorID)
assert.Equal(t, test.msg.ObjectID, msg.ObjectID) assert.Equal(t, test.msg.ObjectID, msg.ObjectID)
assert.Equal(t, test.msg.TargetID, msg.TargetID) assert.Equal(t, test.msg.TargetID, msg.TargetID)
assert.Equal(t, test.msg.Request.Method, msg.Request.Method) assert.Equal(t, test.msg.Request.Method, msg.Request.Method)
assert.Equal(t, test.msg.Request.URL, msg.Request.URL) assert.Equal(t, test.msg.Request.URL, msg.Request.URL)
assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body)) assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body))
assert.Equal(t, test.msg.Request.Header, msg.Request.Header)
} }
} }
// toRequest creates httpclient.Request from HTTP method, URL and body data. // toRequest creates httpclient.Request from HTTP method, URL and body data.
func toRequest(method string, url string, body []byte) httpclient.Request { func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {
var rbody io.Reader var rbody io.Reader
if body != nil { if body != nil {
rbody = bytes.NewReader(body) rbody = bytes.NewReader(body)
@ -109,6 +107,11 @@ func toRequest(method string, url string, body []byte) httpclient.Request {
if err != nil { if err != nil {
panic(err) panic(err)
} }
for key, values := range hdr {
for _, value := range values {
req.Header.Add(key, value)
}
}
return httpclient.WrapRequest(req) return httpclient.WrapRequest(req)
} }

View File

@ -19,6 +19,7 @@ package delivery
import ( import (
"context" "context"
"errors"
"slices" "slices"
"time" "time"
@ -160,6 +161,13 @@ func (w *Worker) process(ctx context.Context) bool {
loop: loop:
for { for {
// Before trying to get
// next delivery, check
// context still valid.
if ctx.Err() != nil {
return true
}
// Get next delivery. // Get next delivery.
dlv, ok := w.next(ctx) dlv, ok := w.next(ctx)
if !ok { if !ok {
@ -195,16 +203,30 @@ loop:
// Attempt delivery of AP request. // Attempt delivery of AP request.
rsp, retry, err := w.Client.DoOnce( rsp, retry, err := w.Client.DoOnce(
&dlv.Request, dlv.Request,
) )
if err == nil { switch {
case err == nil:
// Ensure body closed. // Ensure body closed.
_ = rsp.Body.Close() _ = rsp.Body.Close()
continue loop continue loop
}
if !retry { case errors.Is(err, context.Canceled) &&
ctx.Err() != nil:
// In the case of our own context
// being cancelled, push delivery
// back onto queue for persisting.
//
// Note we specifically check against
// context.Canceled here as it will
// be faster than the mutex lock of
// ctx.Err(), so gives an initial
// faster check in the if-clause.
w.Queue.Push(dlv)
continue loop
case !retry:
// Drop deliveries when no // Drop deliveries when no
// retry requested, or they // retry requested, or they
// reached max (either). // reached max (either).
@ -222,42 +244,36 @@ loop:
// next gets the next available delivery, blocking until available if necessary. // next gets the next available delivery, blocking until available if necessary.
func (w *Worker) next(ctx context.Context) (*Delivery, bool) { func (w *Worker) next(ctx context.Context) (*Delivery, bool) {
loop: // Try a fast-pop of queued
for { // delivery before anything.
// Try pop next queued. dlv, ok := w.Queue.Pop()
dlv, ok := w.Queue.Pop()
if !ok { if !ok {
// Check the backlog. // Check the backlog.
if len(w.backlog) > 0 { if len(w.backlog) > 0 {
// Sort by 'next' time. // Sort by 'next' time.
sortDeliveries(w.backlog) sortDeliveries(w.backlog)
// Pop next delivery. // Pop next delivery.
dlv := w.popBacklog() dlv := w.popBacklog()
return dlv, true return dlv, true
}
select {
// Backlog is empty, we MUST
// block until next enqueued.
case <-w.Queue.Wait():
continue loop
// Worker was stopped.
case <-ctx.Done():
return nil, false
}
} }
// Replace request context for worker state canceling. // Block on next delivery push
ctx := gtscontext.WithValues(ctx, dlv.Request.Context()) // OR worker context canceled.
dlv.Request.Request = dlv.Request.Request.WithContext(ctx) dlv, ok = w.Queue.PopCtx(ctx)
if !ok {
return dlv, true return nil, false
}
} }
// Replace request context for worker state canceling.
ctx = gtscontext.WithValues(ctx, dlv.Request.Context())
dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
return dlv, true
} }
// popBacklog pops next available from the backlog. // popBacklog pops next available from the backlog.

View File

@ -30,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/httpclient" "github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
"github.com/superseriousbusiness/httpsig" "github.com/superseriousbusiness/httpsig"
) )
@ -50,6 +51,10 @@ type Transport interface {
// transport client, retrying on certain preset errors. // transport client, retrying on certain preset errors.
POST(*http.Request, []byte) (*http.Response, error) POST(*http.Request, []byte) (*http.Response, error)
// SignDelivery adds HTTP request signing client "middleware"
// to the request context within given delivery.Delivery{}.
SignDelivery(*delivery.Delivery) error
// Deliver sends an ActivityStreams object. // Deliver sends an ActivityStreams object.
Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error

View File

@ -19,6 +19,7 @@ package workers
import ( import (
"context" "context"
"errors"
"codeberg.org/gruf/go-runners" "codeberg.org/gruf/go-runners"
"codeberg.org/gruf/go-structr" "codeberg.org/gruf/go-structr"
@ -147,9 +148,25 @@ func (w *MsgWorker[T]) process(ctx context.Context) {
return return
} }
// Attempt to process popped message type. // Attempt to process message.
if err := w.Process(ctx, msg); err != nil { err := w.Process(ctx, msg)
if err != nil {
log.Errorf(ctx, "%p: error processing: %v", w, err) log.Errorf(ctx, "%p: error processing: %v", w, err)
if errors.Is(err, context.Canceled) &&
ctx.Err() != nil {
// In the case of our own context
// being cancelled, push message
// back onto queue for persisting.
//
// Note we specifically check against
// context.Canceled here as it will
// be faster than the mutex lock of
// ctx.Err(), so gives an initial
// faster check in the if-clause.
w.Queue.Push(msg)
break
}
} }
} }
} }

View File

@ -55,7 +55,8 @@ type Workers struct {
// StartScheduler starts the job scheduler. // StartScheduler starts the job scheduler.
func (w *Workers) StartScheduler() { func (w *Workers) StartScheduler() {
_ = w.Scheduler.Start() // false = already running _ = w.Scheduler.Start()
// false = already running
log.Info(nil, "started scheduler") log.Info(nil, "started scheduler")
} }
@ -82,9 +83,12 @@ func (w *Workers) Start() {
log.Infof(nil, "started %d dereference workers", n) log.Infof(nil, "started %d dereference workers", n)
} }
// Stop will stop all of the contained worker pools (and global scheduler). // Stop will stop all of the contained
// worker pools (and global scheduler).
func (w *Workers) Stop() { func (w *Workers) Stop() {
_ = w.Scheduler.Stop() // false = not running _ = w.Scheduler.Stop()
// false = not running
log.Info(nil, "stopped scheduler")
w.Delivery.Stop() w.Delivery.Stop()
log.Info(nil, "stopped delivery workers") log.Info(nil, "stopped delivery workers")

View File

@ -29,6 +29,8 @@ import (
var testModels = []interface{}{ var testModels = []interface{}{
&gtsmodel.Account{}, &gtsmodel.Account{},
&gtsmodel.AccountNote{},
&gtsmodel.AccountSettings{},
&gtsmodel.AccountToEmoji{}, &gtsmodel.AccountToEmoji{},
&gtsmodel.Application{}, &gtsmodel.Application{},
&gtsmodel.Block{}, &gtsmodel.Block{},
@ -67,8 +69,7 @@ var testModels = []interface{}{
&gtsmodel.Tombstone{}, &gtsmodel.Tombstone{},
&gtsmodel.Report{}, &gtsmodel.Report{},
&gtsmodel.Rule{}, &gtsmodel.Rule{},
&gtsmodel.AccountNote{}, &gtsmodel.WorkerTask{},
&gtsmodel.AccountSettings{},
} }
// NewTestDB returns a new initialized, empty database for testing. // NewTestDB returns a new initialized, empty database for testing.