From 6a29c5ffd40f1919cac40030c53160c19812bc8d Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Fri, 28 Apr 2023 16:45:21 +0100 Subject: [PATCH] [performance] improved request batching (removes need for queueing) (#1687) * revamp http client to not limit requests, instead use sender worker Signed-off-by: kim * remove separate sender worker pool, spawn 2*GOMAXPROCS batch senders each time, no need for transport cache sweeping Signed-off-by: kim * improve batch senders to keep popping recipients until remote URL found Signed-off-by: kim * fix recipient looping issue Signed-off-by: kim * fix missing mutex unlock Signed-off-by: kim * move request id ctx key to gtscontext, finish filling out more code comments, add basic support for not logging client IP Signed-off-by: kim * slight code reformatting Signed-off-by: kim * a whitespace Signed-off-by: kim * remove unused code Signed-off-by: kim * add missing license headers Signed-off-by: kim * fix request backoff calculation Signed-off-by: kim --------- Signed-off-by: kim --- go.mod | 2 +- internal/api/util/errorhandling.go | 6 +- internal/federation/authenticate.go | 4 +- internal/federation/federatingprotocol.go | 9 +- internal/gtscontext/context.go | 52 +++- internal/gtscontext/log_hooks.go | 44 +++ internal/httpclient/client.go | 264 ++++++++++++------ internal/httpclient/client_test.go | 8 - internal/httpclient/request.go | 62 ---- .../context_test.go => httpclient/sign.go} | 19 +- internal/middleware/logger.go | 19 +- internal/middleware/requestid.go | 27 +- internal/processing/account/get.go | 4 +- internal/processing/fedi/common.go | 4 +- internal/processing/fedi/user.go | 4 +- internal/processing/media/getfile.go | 4 +- internal/processing/search.go | 10 +- internal/processing/util.go | 4 +- internal/transport/context.go | 42 --- internal/transport/controller.go | 24 +- internal/transport/deliver.go | 111 +++++--- internal/transport/transport.go | 187 ++----------- internal/workers/workers.go | 6 +- testrig/transportcontroller.go | 8 +- 24 files changed, 431 insertions(+), 493 deletions(-) create mode 100644 internal/gtscontext/log_hooks.go delete mode 100644 internal/httpclient/request.go rename internal/{transport/context_test.go => httpclient/sign.go} (72%) delete mode 100644 internal/transport/context.go diff --git a/go.mod b/go.mod index 557398263..5709f2e96 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/abema/go-mp4 v0.10.1 github.com/buckket/go-blurhash v1.1.0 github.com/coreos/go-oidc/v3 v3.5.0 - github.com/cornelk/hashmap v1.0.8 github.com/disintegration/imaging v1.6.2 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 @@ -82,6 +81,7 @@ require ( github.com/cilium/ebpf v0.9.1 // indirect github.com/containerd/cgroups/v3 v3.0.1 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/cornelk/hashmap v1.0.8 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/docker/go-units v0.4.0 // indirect github.com/dsoprea/go-exif/v3 v3.0.0-20210625224831-a6301f85c82b // indirect diff --git a/internal/api/util/errorhandling.go b/internal/api/util/errorhandling.go index 45bcf1d7a..4daaf44c8 100644 --- a/internal/api/util/errorhandling.go +++ b/internal/api/util/errorhandling.go @@ -24,9 +24,9 @@ import ( "codeberg.org/gruf/go-kv" "github.com/gin-gonic/gin" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/middleware" ) // TODO: add more templated html pages here for different error types @@ -51,7 +51,7 @@ func NotFoundHandler(c *gin.Context, instanceGet func(ctx context.Context) (*api c.HTML(http.StatusNotFound, "404.tmpl", gin.H{ "instance": instance, - "requestID": middleware.RequestID(ctx), + "requestID": gtscontext.RequestID(ctx), }) default: c.JSON(http.StatusNotFound, gin.H{ @@ -76,7 +76,7 @@ func genericErrorHandler(c *gin.Context, instanceGet func(ctx context.Context) ( "instance": instance, "code": errWithCode.Code(), "error": errWithCode.Safe(), - "requestID": middleware.RequestID(ctx), + "requestID": gtscontext.RequestID(ctx), }) default: c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/federation/authenticate.go b/internal/federation/authenticate.go index 96436ee0e..5fe4873d4 100644 --- a/internal/federation/authenticate.go +++ b/internal/federation/authenticate.go @@ -34,10 +34,10 @@ import ( "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/transport" ) /* @@ -216,7 +216,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU } log.Tracef(ctx, "proceeding with dereference for uncached public key %s", requestingPublicKeyID) - trans, err := f.transportController.NewTransportForUsername(transport.WithFastfail(ctx), requestedUsername) + trans, err := f.transportController.NewTransportForUsername(gtscontext.SetFastFail(ctx), requestedUsername) if err != nil { errWithCode := gtserror.NewErrorInternalError(fmt.Errorf("error creating transport for %s: %s", requestedUsername, err)) log.Debug(ctx, errWithCode) diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go index 52f46586d..7995faa84 100644 --- a/internal/federation/federatingprotocol.go +++ b/internal/federation/federatingprotocol.go @@ -29,10 +29,10 @@ import ( "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/uris" "github.com/superseriousbusiness/gotosocial/internal/util" ) @@ -191,9 +191,8 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr return ctx, false, err } - // We don't yet have an entry for - // the instance, go dereference it. - instance, err := f.GetRemoteInstance(transport.WithFastfail(ctx), username, &url.URL{ + // we don't have an entry for this instance yet so dereference it + instance, err := f.GetRemoteInstance(gtscontext.SetFastFail(ctx), username, &url.URL{ Scheme: publicKeyOwnerURI.Scheme, Host: publicKeyOwnerURI.Host, }) @@ -212,7 +211,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr // dereference the remote account (or just get it // from the db if we already have it). requestingAccount, err := f.GetAccountByURI( - transport.WithFastfail(ctx), username, publicKeyOwnerURI, false, + gtscontext.SetFastFail(ctx), username, publicKeyOwnerURI, false, ) if err != nil { if gtserror.StatusCode(err) == http.StatusGone { diff --git a/internal/gtscontext/context.go b/internal/gtscontext/context.go index 7d4a44774..d52bf2801 100644 --- a/internal/gtscontext/context.go +++ b/internal/gtscontext/context.go @@ -17,7 +17,9 @@ package gtscontext -import "context" +import ( + "context" +) // package private context key type. type ctxkey uint @@ -26,8 +28,54 @@ const ( // context keys. _ ctxkey = iota barebonesKey + fastFailKey + pubKeyIDKey + requestIDKey ) +// RequestID returns the request ID associated with context. This value will usually +// be set by the request ID middleware handler, either pulling an existing supplied +// value from request headers, or generating a unique new entry. This is useful for +// tying together log entries associated with an original incoming request. +func RequestID(ctx context.Context) string { + id, _ := ctx.Value(requestIDKey).(string) + return id +} + +// SetRequestID stores the given request ID value and returns the wrapped +// context. See RequestID() for further information on the request ID value. +func SetRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, requestIDKey, id) +} + +// PublicKeyID returns the public key ID (URI) associated with context. This +// value is useful for logging situations in which a given public key URI is +// relevant, e.g. for outgoing requests being signed by the given key. +func PublicKeyID(ctx context.Context) string { + id, _ := ctx.Value(pubKeyIDKey).(string) + return id +} + +// SetPublicKeyID stores the given public key ID value and returns the wrapped +// context. See PublicKeyID() for further information on the public key ID value. +func SetPublicKeyID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, pubKeyIDKey, id) +} + +// IsFastFail returns whether the "fastfail" context key has been set. This +// can be used to indicate to an http client, for example, that the result +// of an outgoing request is time sensitive and so not to bother with retries. +func IsFastfail(ctx context.Context) bool { + _, ok := ctx.Value(fastFailKey).(struct{}) + return ok +} + +// SetFastFail sets the "fastfail" context flag and returns this wrapped context. +// See IsFastFail() for further information on the "fastfail" context flag. +func SetFastFail(ctx context.Context) context.Context { + return context.WithValue(ctx, fastFailKey, struct{}{}) +} + // Barebones returns whether the "barebones" context key has been set. This // can be used to indicate to the database, for example, that only a barebones // model need be returned, Allowing it to skip populating sub models. @@ -37,7 +85,7 @@ func Barebones(ctx context.Context) bool { } // SetBarebones sets the "barebones" context flag and returns this wrapped context. -// See Barebones() for further information on the "barebones" context flag.. +// See Barebones() for further information on the "barebones" context flag. func SetBarebones(ctx context.Context) context.Context { return context.WithValue(ctx, barebonesKey, struct{}{}) } diff --git a/internal/gtscontext/log_hooks.go b/internal/gtscontext/log_hooks.go new file mode 100644 index 000000000..2fe43e488 --- /dev/null +++ b/internal/gtscontext/log_hooks.go @@ -0,0 +1,44 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gtscontext + +import ( + "context" + + "codeberg.org/gruf/go-kv" + "github.com/superseriousbusiness/gotosocial/internal/log" +) + +func init() { + // Add our required logging hooks on application initialization. + // + // Request ID middleware hook. + log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { + if id := RequestID(ctx); id != "" { + return append(kvs, kv.Field{K: "requestID", V: id}) + } + return kvs + }) + // Client IP middleware hook. + log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { + if id := PublicKeyID(ctx); id != "" { + return append(kvs, kv.Field{K: "pubKeyID", V: id}) + } + return kvs + }) +} diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index 9562bdc48..67a1d0715 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -18,31 +18,39 @@ package httpclient import ( + "context" + "crypto/x509" "errors" + "fmt" "io" "net" "net/http" "net/netip" "runtime" + "strconv" + "strings" "time" "codeberg.org/gruf/go-bytesize" + "codeberg.org/gruf/go-byteutil" + "codeberg.org/gruf/go-cache/v3" + errorsv2 "codeberg.org/gruf/go-errors/v2" "codeberg.org/gruf/go-kv" - "github.com/cornelk/hashmap" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/log" ) -// ErrInvalidRequest is returned if a given HTTP request is invalid and cannot be performed. -var ErrInvalidRequest = errors.New("invalid http request") +var ( + // ErrInvalidNetwork is returned if the request would not be performed over TCP + ErrInvalidNetwork = errors.New("invalid network type") -// ErrInvalidNetwork is returned if the request would not be performed over TCP -var ErrInvalidNetwork = errors.New("invalid network type") + // ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. + ErrReservedAddr = errors.New("dial within blocked / reserved IP range") -// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. -var ErrReservedAddr = errors.New("dial within blocked / reserved IP range") - -// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). -var ErrBodyTooLarge = errors.New("body size too large") + // ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). + ErrBodyTooLarge = errors.New("body size too large") +) // Config provides configuration details for setting up a new // instance of httpclient.Client{}. Within are a subset of the @@ -83,13 +91,10 @@ type Config struct { // cases to protect against forged / unknown content-lengths // - protection from server side request forgery (SSRF) by only dialing // out to known public IP prefixes, configurable with allows/blocks -// - limit number of concurrent requests, else blocking until a slot -// is available (context channels still respected) type Client struct { - client http.Client - queue *hashmap.Map[string, chan struct{}] - bmax int64 // max response body size - cmax int // max open conns per host + client http.Client + badHosts cache.Cache[string, struct{}] + bodyMax int64 } // New returns a new instance of Client initialized using configuration. @@ -109,28 +114,26 @@ func New(cfg Config) *Client { } if cfg.MaxIdleConns <= 0 { - // By default base this value on MaxOpenConns + // By default base this value on MaxOpenConns. cfg.MaxIdleConns = cfg.MaxOpenConnsPerHost * 10 } if cfg.MaxBodySize <= 0 { - // By default set this to a reasonable 40MB + // By default set this to a reasonable 40MB. cfg.MaxBodySize = int64(40 * bytesize.MiB) } - // Protect dialer with IP range sanitizer + // Protect dialer with IP range sanitizer. d.Control = (&sanitizer{ allow: cfg.AllowRanges, block: cfg.BlockRanges, }).Sanitize - // Prepare client fields + // Prepare client fields. c.client.Timeout = cfg.Timeout - c.cmax = cfg.MaxOpenConnsPerHost - c.bmax = cfg.MaxBodySize - c.queue = hashmap.New[string, chan struct{}]() + c.bodyMax = cfg.MaxBodySize - // Set underlying HTTP client roundtripper + // Set underlying HTTP client roundtripper. c.client.Transport = &http.Transport{ Proxy: http.ProxyFromEnvironment, ForceAttemptHTTP2: true, @@ -144,90 +147,185 @@ func New(cfg Config) *Client { DisableCompression: cfg.DisableCompression, } + // Initiate outgoing bad hosts lookup cache. + c.badHosts = cache.New[string, struct{}](0, 1000, 0) + c.badHosts.SetTTL(15*time.Minute, false) + if !c.badHosts.Start(time.Minute) { + log.Panic(nil, "failed to start transport controller cache") + } + return &c } -// Do will perform given request when an available slot in the queue is available, -// and block until this time. For returned values, this follows the same semantics -// as the standard http.Client{}.Do() implementation except that response body will -// be wrapped by an io.LimitReader() to limit response body sizes. -func (c *Client) Do(req *http.Request) (*http.Response, error) { - // Ensure this is a valid request - if err := ValidateRequest(req); err != nil { - return nil, err +// Do ... +func (c *Client) Do(r *http.Request) (*http.Response, error) { + return c.DoSigned(r, func(r *http.Request) error { + return nil // no request signing + }) +} + +// DoSigned ... +func (c *Client) DoSigned(r *http.Request, sign SignFunc) (*http.Response, error) { + const ( + // max no. attempts. + maxRetries = 5 + + // starting backoff duration. + baseBackoff = 2 * time.Second + ) + + // Get request hostname. + host := r.URL.Hostname() + + // Check whether request should fast fail. + fastFail := gtscontext.IsFastfail(r.Context()) + if !fastFail { + // Check if recently reached max retries for this host + // so we don't bother with a retry-backoff loop. The only + // errors that are retried upon are server failure and + // domain resolution type errors, so this cached result + // indicates this server is likely having issues. + fastFail = c.badHosts.Has(host) } - // Get host's wait queue - wait := c.wait(req.Host) + // Start a log entry for this request + l := log.WithContext(r.Context()). + WithFields(kv.Fields{ + {"method", r.Method}, + {"url", r.URL.String()}, + }...) - var ok bool + for i := 0; i < maxRetries; i++ { + var backoff time.Duration - select { - // Quickly try grab a spot - case wait <- struct{}{}: - // it's our turn! - ok = true + // Reset signing header fields + now := time.Now().UTC() + r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") + r.Header.Del("Signature") + r.Header.Del("Digest") - // NOTE: - // Ideally here we would set the slot release to happen either - // on error return, or via callback from the response body closer. - // However when implementing this, there appear deadlocks between - // the channel queue here and the media manager worker pool. So - // currently we only place a limit on connections dialing out, but - // there may still be more connections open than len(c.queue) given - // that connections may not be closed until response body is closed. - // The current implementation will reduce the viability of denial of - // service attacks, but if there are future issues heed this advice :] - defer func() { <-wait }() - default: - } + // Rewind body reader and content-length if set. + if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { + r.ContentLength = int64(rc.Len()) + rc.Rewind() + } - if !ok { - // No spot acquired, log warning - log.WithContext(req.Context()). - WithFields(kv.Fields{ - {K: "queue", V: len(wait)}, - {K: "method", V: req.Method}, - {K: "host", V: req.Host}, - {K: "uri", V: req.URL.RequestURI()}, - }...).Warn("full request queue") + // Sign the outgoing request. + if err := sign(r); err != nil { + return nil, err + } + + l.Infof("performing request") + + // Perform the request. + rsp, err := c.do(r) + if err == nil { //nolint:gocritic + + // TooManyRequest means we need to slow + // down and retry our request. Codes over + // 500 generally indicate temp. outages. + if code := rsp.StatusCode; code < 500 && + code != http.StatusTooManyRequests { + return rsp, nil + } + + // Generate error from status code for logging + err = errors.New(`http response "` + rsp.Status + `"`) + + // Search for a provided "Retry-After" header value. + if after := rsp.Header.Get("Retry-After"); after != "" { + + if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { + // An integer number of backoff seconds was provided. + backoff = time.Duration(u) * time.Second + } else if at, _ := http.ParseTime(after); !at.Before(now) { + // An HTTP formatted future date-time was provided. + backoff = at.Sub(now) + } + + // Don't let their provided backoff exceed our max. + if max := baseBackoff * maxRetries; backoff > max { + backoff = max + } + } + + } else if errorsv2.Is(err, + context.DeadlineExceeded, + context.Canceled, + ErrBodyTooLarge, + ErrReservedAddr, + ) { + // Return on non-retryable errors + return nil, err + } else if strings.Contains(err.Error(), "stopped after 10 redirects") { + // Don't bother if net/http returned after too many redirects + return nil, err + } else if errors.As(err, &x509.UnknownAuthorityError{}) { + // Unknown authority errors we do NOT recover from + return nil, err + } else if dnserr := (*net.DNSError)(nil); // nocollapse + errors.As(err, &dnserr) && dnserr.IsNotFound { + // DNS lookup failure, this domain does not exist + return nil, gtserror.SetNotFound(err) + } + + if fastFail { + // on fast-fail, don't bother backoff/retry + return nil, fmt.Errorf("%w (fast fail)", err) + } + + if backoff == 0 { + // No retry-after found, set our predefined + // backoff according to a multiplier of 2^n. + backoff = baseBackoff * 1 << (i + 1) + } + + l.Errorf("backing off for %s after http request error: %v", backoff, err) select { - case <-req.Context().Done(): - // the request was canceled before we - // got to our turn: no need to release - return nil, req.Context().Err() - case wait <- struct{}{}: - defer func() { <-wait }() + // Request ctx cancelled + case <-r.Context().Done(): + return nil, r.Context().Err() + + // Backoff for some time + case <-time.After(backoff): } } - // Perform the HTTP request + // Add "bad" entry for this host. + c.badHosts.Set(host, struct{}{}) + + return nil, errors.New("transport reached max retries") +} + +// do ... +func (c *Client) do(req *http.Request) (*http.Response, error) { + // Perform the HTTP request. rsp, err := c.client.Do(req) if err != nil { return nil, err } - // Check response body not too large - if rsp.ContentLength > c.bmax { + // Check response body not too large. + if rsp.ContentLength > c.bodyMax { return nil, ErrBodyTooLarge } - // Seperate the body implementers + // Seperate the body implementers. rbody := (io.Reader)(rsp.Body) cbody := (io.Closer)(rsp.Body) var limit int64 if limit = rsp.ContentLength; limit < 0 { - // If unknown, use max as reader limit - limit = c.bmax + // If unknown, use max as reader limit. + limit = c.bodyMax } - // Don't trust them, limit body reads + // Don't trust them, limit body reads. rbody = io.LimitReader(rbody, limit) - // Wrap body with limit + // Wrap body with limit. rsp.Body = &struct { io.Reader io.Closer @@ -235,17 +333,3 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { return rsp, nil } - -// wait acquires the 'wait' queue for the given host string, or allocates new. -func (c *Client) wait(host string) chan struct{} { - // Look for an existing queue - queue, ok := c.queue.Get(host) - if ok { - return queue - } - - // Allocate a new host queue (or return a sneaky existing one). - queue, _ = c.queue.GetOrInsert(host, make(chan struct{}, c.cmax)) - - return queue -} diff --git a/internal/httpclient/client_test.go b/internal/httpclient/client_test.go index 9eab0fed4..f0ec01ec3 100644 --- a/internal/httpclient/client_test.go +++ b/internal/httpclient/client_test.go @@ -48,14 +48,6 @@ var bodies = []string{ "body with\r\nnewlines", } -// Note: -// There is no test for the .MaxOpenConns implementation -// in the httpclient.Client{}, due to the difficult to test -// this. The block is only held for the actual dial out to -// the connection, so the usual test of blocking and holding -// open this queue slot to check we can't open another isn't -// an easy test here. - func TestHTTPClientSmallBody(t *testing.T) { for _, body := range bodies { _TestHTTPClientWithBody(t, []byte(body), int(^uint16(0))) diff --git a/internal/httpclient/request.go b/internal/httpclient/request.go deleted file mode 100644 index 881d3f699..000000000 --- a/internal/httpclient/request.go +++ /dev/null @@ -1,62 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package httpclient - -import ( - "fmt" - "net/http" - "strings" - - "golang.org/x/net/http/httpguts" -) - -// ValidateRequest performs the same request validation logic found in the default -// net/http.Transport{}.roundTrip() function, but pulls it out into this separate -// function allowing validation errors to be wrapped under a single error type. -func ValidateRequest(r *http.Request) error { - switch { - case r.URL == nil: - return fmt.Errorf("%w: nil url", ErrInvalidRequest) - case r.Header == nil: - return fmt.Errorf("%w: nil header", ErrInvalidRequest) - case r.URL.Host == "": - return fmt.Errorf("%w: empty url host", ErrInvalidRequest) - case r.URL.Scheme != "http" && r.URL.Scheme != "https": - return fmt.Errorf("%w: unsupported protocol %q", ErrInvalidRequest, r.URL.Scheme) - case strings.IndexFunc(r.Method, func(r rune) bool { return !httpguts.IsTokenRune(r) }) != -1: - return fmt.Errorf("%w: invalid method %q", ErrInvalidRequest, r.Method) - } - - for key, values := range r.Header { - // Check field key name is valid - if !httpguts.ValidHeaderFieldName(key) { - return fmt.Errorf("%w: invalid header field name %q", ErrInvalidRequest, key) - } - - // Check each field value is valid - for i := 0; i < len(values); i++ { - if !httpguts.ValidHeaderFieldValue(values[i]) { - return fmt.Errorf("%w: invalid header field value %q", ErrInvalidRequest, values[i]) - } - } - } - - // ps. kim wrote this - - return nil -} diff --git a/internal/transport/context_test.go b/internal/httpclient/sign.go similarity index 72% rename from internal/transport/context_test.go rename to internal/httpclient/sign.go index e06e7c4d5..78046aa28 100644 --- a/internal/transport/context_test.go +++ b/internal/httpclient/sign.go @@ -15,19 +15,14 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package transport_test +package httpclient -import ( - "context" - "testing" +import "net/http" - "github.com/superseriousbusiness/gotosocial/internal/transport" -) +// SignFunc is a function signature that provides request signing. +type SignFunc func(r *http.Request) error -func TestFastFailContext(t *testing.T) { - ctx := context.Background() - ctx = transport.WithFastfail(ctx) - if !transport.IsFastfail(ctx) { - t.Fatal("failed to set fast-fail context key") - } +type SigningClient interface { + Do(r *http.Request) (*http.Response, error) + DoSigned(r *http.Request, sign SignFunc) (*http.Response, error) } diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go index e80488330..50e5542c3 100644 --- a/internal/middleware/logger.go +++ b/internal/middleware/logger.go @@ -34,7 +34,7 @@ import ( func Logger() gin.HandlerFunc { return func(c *gin.Context) { // Initialize the logging fields - fields := make(kv.Fields, 6, 7) + fields := make(kv.Fields, 5, 7) // Determine pre-handler time before := time.Now() @@ -68,11 +68,18 @@ func Logger() gin.HandlerFunc { // Set request logging fields fields[0] = kv.Field{"latency", time.Since(before)} - fields[1] = kv.Field{"clientIP", c.ClientIP()} - fields[2] = kv.Field{"userAgent", c.Request.UserAgent()} - fields[3] = kv.Field{"method", c.Request.Method} - fields[4] = kv.Field{"statusCode", code} - fields[5] = kv.Field{"path", path} + fields[1] = kv.Field{"userAgent", c.Request.UserAgent()} + fields[2] = kv.Field{"method", c.Request.Method} + fields[3] = kv.Field{"statusCode", code} + fields[4] = kv.Field{"path", path} + if includeClientIP := true; includeClientIP { + // TODO: make this configurable. + // + // Include clientIP if enabled. + fields = append(fields, kv.Field{ + "clientIP", c.ClientIP(), + }) + } // Create log entry with fields l := log.WithContext(c.Request.Context()). diff --git a/internal/middleware/requestid.go b/internal/middleware/requestid.go index 27189b219..6e2a83c68 100644 --- a/internal/middleware/requestid.go +++ b/internal/middleware/requestid.go @@ -19,7 +19,6 @@ package middleware import ( "bufio" - "context" "crypto/rand" "encoding/base32" "encoding/binary" @@ -27,17 +26,11 @@ import ( "sync" "time" - "codeberg.org/gruf/go-kv" "github.com/gin-gonic/gin" - "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" ) -type ctxType string - var ( - // ridCtxKey is the key underwhich we store request IDs in a context. - ridCtxKey ctxType = "id" - // crand provides buffered reads of random input. crand = bufio.NewReader(rand.Reader) mrand sync.Mutex @@ -69,22 +62,8 @@ func generateID() string { return base32enc.EncodeToString(b) } -// RequestID fetches the stored request ID from context. -func RequestID(ctx context.Context) string { - id, _ := ctx.Value(ridCtxKey).(string) - return id -} - // AddRequestID returns a gin middleware which adds a unique ID to each request (both response header and context). func AddRequestID(header string) gin.HandlerFunc { - log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { - if id, _ := ctx.Value(ridCtxKey).(string); id != "" { - // Add stored request ID to log entry fields. - return append(kvs, kv.Field{K: "requestID", V: id}) - } - return kvs - }) - return func(c *gin.Context) { // Look for existing ID. id := c.GetHeader(header) @@ -100,8 +79,8 @@ func AddRequestID(header string) gin.HandlerFunc { c.Request.Header.Set(header, id) } - // Store request ID in new request ctx and set new gin request obj. - ctx := context.WithValue(c.Request.Context(), ridCtxKey, id) + // Store request ID in new request context and set on gin ctx. + ctx := gtscontext.SetRequestID(c.Request.Context(), id) c.Request = c.Request.WithContext(ctx) // Set the request ID in the rsp header. diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go index 84d00c46b..d0ea96ca2 100644 --- a/internal/processing/account/get.go +++ b/internal/processing/account/get.go @@ -25,9 +25,9 @@ import ( apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/transport" ) // Get processes the given request for account information. @@ -96,7 +96,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco } a, err := p.federator.GetAccountByURI( - transport.WithFastfail(ctx), requestingAccount.Username, targetAccountURI, true, + gtscontext.SetFastFail(ctx), requestingAccount.Username, targetAccountURI, true, ) if err == nil { targetAccount = a diff --git a/internal/processing/fedi/common.go b/internal/processing/fedi/common.go index 91b3030e1..3fade397b 100644 --- a/internal/processing/fedi/common.go +++ b/internal/processing/fedi/common.go @@ -22,9 +22,9 @@ import ( "fmt" "net/url" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/transport" ) func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) { @@ -40,7 +40,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string) return } - if requestingAccount, err = p.federator.GetAccountByURI(transport.WithFastfail(ctx), requestedUsername, requestingAccountURI, false); err != nil { + if requestingAccount, err = p.federator.GetAccountByURI(gtscontext.SetFastFail(ctx), requestedUsername, requestingAccountURI, false); err != nil { errWithCode = gtserror.NewErrorUnauthorized(err) return } diff --git a/internal/processing/fedi/user.go b/internal/processing/fedi/user.go index 3343ae8bc..28dc3c857 100644 --- a/internal/processing/fedi/user.go +++ b/internal/processing/fedi/user.go @@ -24,8 +24,8 @@ import ( "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/uris" ) @@ -56,7 +56,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque // if we're not already handshaking/dereferencing a remote account, dereference it now if !p.federator.Handshaking(requestedUsername, requestingAccountURI) { requestingAccount, err := p.federator.GetAccountByURI( - transport.WithFastfail(ctx), requestedUsername, requestingAccountURI, false, + gtscontext.SetFastFail(ctx), requestedUsername, requestingAccountURI, false, ) if err != nil { return nil, gtserror.NewErrorUnauthorized(err) diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go index 293093ac2..2694fde13 100644 --- a/internal/processing/media/getfile.go +++ b/internal/processing/media/getfile.go @@ -25,10 +25,10 @@ import ( "strings" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/uris" ) @@ -157,7 +157,7 @@ func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount if err != nil { return nil, 0, err } - return t.DereferenceMedia(transport.WithFastfail(innerCtx), remoteMediaIRI) + return t.DereferenceMedia(gtscontext.SetFastFail(innerCtx), remoteMediaIRI) } // Start recaching this media with the prepared data function. diff --git a/internal/processing/search.go b/internal/processing/search.go index 0c9ef43fd..624537b6a 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -30,11 +30,11 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/oauth" - "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/util" ) @@ -226,14 +226,14 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a } func (p *Processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL) (*gtsmodel.Status, error) { - status, statusable, err := p.federator.GetStatus(transport.WithFastfail(ctx), authed.Account.Username, uri, true, true) + status, statusable, err := p.federator.GetStatus(gtscontext.SetFastFail(ctx), authed.Account.Username, uri, true, true) if err != nil { return nil, err } if !*status.Local && statusable != nil { // Attempt to dereference the status thread while we are here - p.federator.DereferenceThread(transport.WithFastfail(ctx), authed.Account.Username, uri, status, statusable) + p.federator.DereferenceThread(gtscontext.SetFastFail(ctx), authed.Account.Username, uri, status, statusable) } return status, nil @@ -268,7 +268,7 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, } return p.federator.GetAccountByURI( - transport.WithFastfail(ctx), + gtscontext.SetFastFail(ctx), authed.Account.Username, uri, false, ) @@ -295,7 +295,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o } return p.federator.GetAccountByUsernameDomain( - transport.WithFastfail(ctx), + gtscontext.SetFastFail(ctx), authed.Account.Username, username, domain, false, ) diff --git a/internal/processing/util.go b/internal/processing/util.go index 3f3f7ec79..967c03f9f 100644 --- a/internal/processing/util.go +++ b/internal/processing/util.go @@ -24,9 +24,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" - "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/util" ) @@ -58,7 +58,7 @@ func GetParseMentionFunc(dbConn db.DB, federator federation.Federator) gtsmodel. } remoteAccount, err := federator.GetAccountByUsernameDomain( - transport.WithFastfail(ctx), + gtscontext.SetFastFail(ctx), requestingUsername, username, domain, diff --git a/internal/transport/context.go b/internal/transport/context.go deleted file mode 100644 index 96d3f23f7..000000000 --- a/internal/transport/context.go +++ /dev/null @@ -1,42 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package transport - -import "context" - -// ctxkey is our own unique context key type to prevent setting outside package. -type ctxkey string - -// fastfailkey is our unique context key to indicate fast-fail is enabled. -var fastfailkey = ctxkey("ff") - -// WithFastfail returns a Context which indicates that any http requests made -// with it should return after the first failed attempt, instead of retrying. -// -// This can be used to fail quickly when you're making an outgoing http request -// inside the context of an incoming http request, and you want to be able to -// provide a snappy response to the user, instead of retrying + backing off. -func WithFastfail(parent context.Context) context.Context { - return context.WithValue(parent, fastfailkey, struct{}{}) -} - -// IsFastfail returns true if the given context was created by WithFastfail. -func IsFastfail(ctx context.Context) bool { - _, ok := ctx.Value(fastfailkey).(struct{}) - return ok -} diff --git a/internal/transport/controller.go b/internal/transport/controller.go index 331659f64..e1271d202 100644 --- a/internal/transport/controller.go +++ b/internal/transport/controller.go @@ -24,7 +24,7 @@ import ( "encoding/json" "fmt" "net/url" - "time" + "runtime" "codeberg.org/gruf/go-byteutil" "codeberg.org/gruf/go-cache/v3" @@ -32,7 +32,7 @@ import ( "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" - "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/httpclient" "github.com/superseriousbusiness/gotosocial/internal/state" ) @@ -49,14 +49,14 @@ type controller struct { state *state.State fedDB federatingdb.DB clock pub.Clock - client pub.HttpClient + client httpclient.SigningClient trspCache cache.Cache[string, *transport] - badHosts cache.Cache[string, struct{}] userAgent string + senders int // no. concurrent batch delivery routines. } // NewController returns an implementation of the Controller interface for creating new transports -func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller { +func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client httpclient.SigningClient) Controller { applicationName := config.GetApplicationName() host := config.GetHost() proto := config.GetProtocol() @@ -68,20 +68,8 @@ func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.C clock: clock, client: client, trspCache: cache.New[string, *transport](0, 100, 0), - badHosts: cache.New[string, struct{}](0, 1000, 0), userAgent: fmt.Sprintf("%s (+%s://%s) gotosocial/%s", applicationName, proto, host, version), - } - - // Transport cache has TTL=1hr freq=1min - c.trspCache.SetTTL(time.Hour, false) - if !c.trspCache.Start(time.Minute) { - log.Panic(nil, "failed to start transport controller cache") - } - - // Bad hosts cache has TTL=15min freq=1min - c.badHosts.SetTTL(15*time.Minute, false) - if !c.badHosts.Start(time.Minute) { - log.Panic(nil, "failed to start transport controller cache") + senders: runtime.GOMAXPROCS(0), // on batch delivery, only ever send GOMAXPROCS at a time. } return c diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index 8ec939503..fff7dbcf4 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -22,7 +22,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "sync" "codeberg.org/gruf/go-byteutil" @@ -32,54 +31,90 @@ import ( ) func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error { - // concurrently deliver to recipients; for each delivery, buffer the error if it fails - wg := sync.WaitGroup{} - errCh := make(chan error, len(recipients)) - for _, recipient := range recipients { - wg.Add(1) - go func(r *url.URL) { - defer wg.Done() - if err := t.Deliver(ctx, b, r); err != nil { - errCh <- err + var ( + // errs accumulates errors received during + // attempted delivery by deliverer routines. + errs gtserror.MultiError + + // wait blocks until all sender + // routines have returned. + wait sync.WaitGroup + + // mutex protects 'recipients' and + // 'errs' for concurrent access. + mutex sync.Mutex + + // Get current instance host info. + domain = config.GetAccountDomain() + host = config.GetHost() + ) + + // Block on expect no. senders. + wait.Add(t.controller.senders) + + for i := 0; i < t.controller.senders; i++ { + go func() { + // Mark returned. + defer wait.Done() + + for { + // Acquire lock. + mutex.Lock() + + if len(recipients) == 0 { + // Reached end. + mutex.Unlock() + return + } + + // Pop next recipient. + i := len(recipients) - 1 + to := recipients[i] + recipients = recipients[:i] + + // Done with lock. + mutex.Unlock() + + // Skip delivery to recipient if it is "us". + if to.Host == host || to.Host == domain { + continue + } + + // Attempt to deliver data to recipient. + if err := t.deliver(ctx, b, to); err != nil { + mutex.Lock() // safely append err to accumulator. + errs.Appendf("error delivering to %s: %v", to, err) + mutex.Unlock() + } } - }(recipient) + }() } - // wait until all deliveries have succeeded or failed - wg.Wait() + // Wait for finish. + wait.Wait() - // receive any buffered errors - errs := make([]string, 0, len(errCh)) -outer: - for { - select { - case e := <-errCh: - errs = append(errs, e.Error()) - default: - break outer - } - } - - if len(errs) > 0 { - return fmt.Errorf("BatchDeliver: at least one failure: %s", strings.Join(errs, "; ")) - } - - return nil + // Return combined err. + return errs.Combine() } func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { - // if the 'to' host is our own, just skip this delivery since we by definition already have the message! + // if 'to' host is our own, skip as we don't need to deliver to ourselves... if to.Host == config.GetHost() || to.Host == config.GetAccountDomain() { return nil } - urlStr := to.String() + // Deliver data to recipient. + return t.deliver(ctx, b, to) +} + +func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error { + url := to.String() // Use rewindable bytes reader for body. var body byteutil.ReadNopCloser body.Reset(b) - req, err := http.NewRequestWithContext(ctx, "POST", urlStr, &body) + req, err := http.NewRequestWithContext(ctx, "POST", url, &body) if err != nil { return err } @@ -88,16 +123,16 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { req.Header.Add("Accept-Charset", "utf-8") req.Header.Set("Host", to.Host) - resp, err := t.POST(req, b) + rsp, err := t.POST(req, b) if err != nil { return err } - defer resp.Body.Close() + defer rsp.Body.Close() - if code := resp.StatusCode; code != http.StatusOK && + if code := rsp.StatusCode; code != http.StatusOK && code != http.StatusCreated && code != http.StatusAccepted { - err := fmt.Errorf("POST request to %s failed: %s", urlStr, resp.Status) - return gtserror.WithStatusCode(err, resp.StatusCode) + err := fmt.Errorf("POST request to %s failed: %s", url, rsp.Status) + return gtserror.WithStatusCode(err, rsp.StatusCode) } return nil diff --git a/internal/transport/transport.go b/internal/transport/transport.go index e8f742f5b..0123b3ea8 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -20,26 +20,17 @@ package transport import ( "context" "crypto" - "crypto/x509" "errors" - "fmt" "io" - "net" "net/http" "net/url" - "strconv" - "strings" "sync" "time" - "codeberg.org/gruf/go-byteutil" - errorsv2 "codeberg.org/gruf/go-errors/v2" - "codeberg.org/gruf/go-kv" "github.com/go-fed/httpsig" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/httpclient" - "github.com/superseriousbusiness/gotosocial/internal/log" ) // Transport implements the pub.Transport interface with some additional functionality for fetching remote media. @@ -78,7 +69,7 @@ type Transport interface { Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) } -// transport implements the Transport interface +// transport implements the Transport interface. type transport struct { controller *controller pubKeyID string @@ -95,9 +86,11 @@ func (t *transport) GET(r *http.Request) (*http.Response, error) { if r.Method != http.MethodGet { return nil, errors.New("must be GET request") } - return t.do(r, func(r *http.Request) error { - return t.signGET(r) - }) + ctx := r.Context() // extract, set pubkey ID. + ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) + r = r.WithContext(ctx) // replace request ctx. + r.Header.Set("User-Agent", t.controller.userAgent) + return t.controller.client.DoSigned(r, t.signGET()) } // POST will perform given http request using transport client, retrying on certain preset errors. @@ -105,161 +98,31 @@ func (t *transport) POST(r *http.Request, body []byte) (*http.Response, error) { if r.Method != http.MethodPost { return nil, errors.New("must be POST request") } - return t.do(r, func(r *http.Request) error { - return t.signPOST(r, body) - }) -} - -func (t *transport) do(r *http.Request, signer func(*http.Request) error) (*http.Response, error) { - const ( - // max no. attempts - maxRetries = 5 - - // starting backoff duration. - baseBackoff = 2 * time.Second - ) - - // Get request hostname - host := r.URL.Hostname() - - // Check whether request should fast fail, we check this - // before loop as each context.Value() requires mutex lock. - fastFail := IsFastfail(r.Context()) - if !fastFail { - // Check if recently reached max retries for this host - // so we don't bother with a retry-backoff loop. The only - // errors that are retried upon are server failure and - // domain resolution type errors, so this cached result - // indicates this server is likely having issues. - fastFail = t.controller.badHosts.Has(host) - } - - // Start a log entry for this request - l := log.WithContext(r.Context()). - WithFields(kv.Fields{ - {"pubKeyID", t.pubKeyID}, - {"method", r.Method}, - {"url", r.URL.String()}, - }...) - + ctx := r.Context() // extract, set pubkey ID. + ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) + r = r.WithContext(ctx) // replace request ctx. r.Header.Set("User-Agent", t.controller.userAgent) - - for i := 0; i < maxRetries; i++ { - var backoff time.Duration - - // Reset signing header fields - now := t.controller.clock.Now().UTC() - r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - r.Header.Del("Signature") - r.Header.Del("Digest") - - // Rewind body reader and content-length if set. - if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { - r.ContentLength = int64(rc.Len()) - rc.Rewind() - } - - // Perform request signing - if err := signer(r); err != nil { - return nil, err - } - - l.Infof("performing request") - - // Attempt to perform request - rsp, err := t.controller.client.Do(r) - if err == nil { //nolint:gocritic - // TooManyRequest means we need to slow - // down and retry our request. Codes over - // 500 generally indicate temp. outages. - if code := rsp.StatusCode; code < 500 && - code != http.StatusTooManyRequests { - return rsp, nil - } - - // Generate error from status code for logging - err = errors.New(`http response "` + rsp.Status + `"`) - - // Search for a provided "Retry-After" header value. - if after := rsp.Header.Get("Retry-After"); after != "" { - - if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { - // An integer number of backoff seconds was provided. - backoff = time.Duration(u) * time.Second - } else if at, _ := http.ParseTime(after); !at.Before(now) { - // An HTTP formatted future date-time was provided. - backoff = at.Sub(now) - } - - // Don't let their provided backoff exceed our max. - if max := baseBackoff * maxRetries; backoff > max { - backoff = max - } - } - - } else if errorsv2.Is(err, - context.DeadlineExceeded, - context.Canceled, - httpclient.ErrInvalidRequest, - httpclient.ErrBodyTooLarge, - httpclient.ErrReservedAddr, - ) { - // Return on non-retryable errors - return nil, err - } else if strings.Contains(err.Error(), "stopped after 10 redirects") { - // Don't bother if net/http returned after too many redirects - return nil, err - } else if errors.As(err, &x509.UnknownAuthorityError{}) { - // Unknown authority errors we do NOT recover from - return nil, err - } else if dnserr := (*net.DNSError)(nil); // nocollapse - errors.As(err, &dnserr) && dnserr.IsNotFound { - // DNS lookup failure, this domain does not exist - return nil, gtserror.SetNotFound(err) - } - - if fastFail { - // on fast-fail, don't bother backoff/retry - return nil, fmt.Errorf("%w (fast fail)", err) - } - - if backoff == 0 { - // No retry-after found, set our predefined backoff. - backoff = time.Duration(i) * baseBackoff - } - - l.Errorf("backing off for %s after http request error: %v", backoff, err) - - select { - // Request ctx cancelled - case <-r.Context().Done(): - return nil, r.Context().Err() - - // Backoff for some time - case <-time.After(backoff): - } - } - - // Add "bad" entry for this host. - t.controller.badHosts.Set(host, struct{}{}) - - return nil, errors.New("transport reached max retries") + return t.controller.client.DoSigned(r, t.signPOST(body)) } // signGET will safely sign an HTTP GET request. -func (t *transport) signGET(r *http.Request) (err error) { - t.safesign(func() { - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) - }) - return +func (t *transport) signGET() httpclient.SignFunc { + return func(r *http.Request) (err error) { + t.safesign(func() { + err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) + }) + return + } } // signPOST will safely sign an HTTP POST request for given body. -func (t *transport) signPOST(r *http.Request, body []byte) (err error) { - t.safesign(func() { - err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) - }) - return +func (t *transport) signPOST(body []byte) httpclient.SignFunc { + return func(r *http.Request) (err error) { + t.safesign(func() { + err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) + }) + return + } } // safesign will perform sign function within mutex protection, diff --git a/internal/workers/workers.go b/internal/workers/workers.go index bf64a28ee..aa8e40e1c 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -31,8 +31,12 @@ type Workers struct { // Main task scheduler instance. Scheduler sched.Scheduler - // ClientAPI / federator worker pools. + // ClientAPI provides a worker pool that handles both + // incoming client actions, and our own side-effects. ClientAPI runners.WorkerPool + + // Federator provides a worker pool that handles both + // incoming federated actions, and our own side-effects. Federator runners.WorkerPool // Enqueue functions for clientAPI / federator worker pools, diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index f2c6b1d28..b74888934 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -26,12 +26,12 @@ import ( "strings" "sync" - "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/httpclient" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/transport" @@ -51,7 +51,7 @@ const ( // Unlike the other test interfaces provided in this package, you'll probably want to call this function // PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular) // basis. -func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller { +func NewTestTransportController(state *state.State, client httpclient.SigningClient) transport.Controller { return transport.NewController(state, NewTestFederatingDB(state), &federation.Clock{}, client) } @@ -225,6 +225,10 @@ func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { return m.do(req) } +func (m *MockHTTPClient) DoSigned(req *http.Request, sign httpclient.SignFunc) (*http.Response, error) { + return m.do(req) +} + func HostMetaResponse(req *http.Request) (responseCode int, responseBytes []byte, responseContentType string, responseContentLength int) { var hm *apimodel.HostMeta