mirror of
1
Fork 0

[bugfix] Add back removed ValidateRequest() before backoff-retry loop (#1805)

* add back removed ValidateRequest() before backoff-retry loop

Signed-off-by: kim <grufwub@gmail.com>

* include response body in error response log

Signed-off-by: kim <grufwub@gmail.com>

* improved error response body draining

Signed-off-by: kim <grufwub@gmail.com>

* add more code commenting

Signed-off-by: kim <grufwub@gmail.com>

* move new error response logic to gtserror, handle instead in transport.Transport{} impl

Signed-off-by: kim <grufwub@gmail.com>

* appease ye oh mighty linter

Signed-off-by: kim <grufwub@gmail.com>

* fix mockhttpclient not setting request in http response

Signed-off-by: kim <grufwub@gmail.com>

---------

Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
kim 2023-05-21 17:59:14 +01:00 committed by GitHub
parent 107237c8e8
commit 2063d01cdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 299 additions and 32 deletions

View File

@ -34,8 +34,8 @@ const (
notFoundKey notFoundKey
errorTypeKey errorTypeKey
// error types // Types returnable from Type(...).
TypeSMTP ErrorType = "smtp" // smtp (mail) error TypeSMTP ErrorType = "smtp" // smtp (mail)
) )
// StatusCode checks error for a stored status code value. For example // StatusCode checks error for a stored status code value. For example

66
internal/gtserror/new.go Normal file
View File

@ -0,0 +1,66 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtserror
import (
"errors"
"net/http"
"codeberg.org/gruf/go-byteutil"
)
// NewResponseError crafts an error from provided HTTP response
// including the method, status and body (if any provided). This
// will also wrap the returned error using WithStatusCode().
func NewResponseError(rsp *http.Response) error {
var buf byteutil.Buffer
// Get URL string ahead of time.
urlStr := rsp.Request.URL.String()
// Alloc guesstimate of required buf size.
buf.Guarantee(0 +
len(rsp.Request.Method) +
12 + // request to
len(urlStr) +
17 + // failed: status="
len(rsp.Status) +
8 + // " body="
256 + // max body size
1, // "
)
// Build error message string without
// using "fmt", as chances are this will
// be used in a hot code path and we
// know all the incoming types involved.
_, _ = buf.WriteString(rsp.Request.Method)
_, _ = buf.WriteString(" request to ")
_, _ = buf.WriteString(urlStr)
_, _ = buf.WriteString(" failed: status=\"")
_, _ = buf.WriteString(rsp.Status)
_, _ = buf.WriteString("\" body=\"")
_, _ = buf.WriteString(drainBody(rsp.Body, 256))
_, _ = buf.WriteString("\"")
// Create new error from msg.
err := errors.New(buf.String())
// Wrap error to provide status code.
return WithStatusCode(err, rsp.StatusCode)
}

View File

@ -0,0 +1,91 @@
package gtserror_test
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
)
func TestResponseError(t *testing.T) {
testResponseError(t, http.Response{
Body: toBody(`{"error": "user not found"}`),
Request: &http.Request{
Method: "GET",
URL: toURL("https://google.com/users/sundar"),
},
Status: "404 Not Found",
})
testResponseError(t, http.Response{
Body: toBody("Unauthorized"),
Request: &http.Request{
Method: "POST",
URL: toURL("https://google.com/inbox"),
},
Status: "401 Unauthorized",
})
testResponseError(t, http.Response{
Body: toBody(""),
Request: &http.Request{
Method: "GET",
URL: toURL("https://google.com/users/sundar"),
},
Status: "404 Not Found",
})
}
func testResponseError(t *testing.T, rsp http.Response) {
var body string
if rsp.Body == http.NoBody {
body = "<empty>"
} else {
var b []byte
rsp.Body, b = copyBody(rsp.Body)
trunc := len(b)
if trunc > 256 {
trunc = 256
}
body = string(b[:trunc])
}
expect := fmt.Sprintf(
"%s request to %s failed: status=\"%s\" body=\"%s\"",
rsp.Request.Method,
rsp.Request.URL.String(),
rsp.Status,
body,
)
err := gtserror.NewResponseError(&rsp)
if str := err.Error(); str != expect {
t.Errorf("unexpected error string: recv=%q expct=%q", str, expect)
}
}
func toURL(u string) *url.URL {
url, err := url.Parse(u)
if err != nil {
panic(err)
}
return url
}
func toBody(s string) io.ReadCloser {
if s == "" {
return http.NoBody
}
r := strings.NewReader(s)
return io.NopCloser(r)
}
func copyBody(rc io.ReadCloser) (io.ReadCloser, []byte) {
b, err := io.ReadAll(rc)
if err != nil {
panic(err)
}
r := bytes.NewReader(b)
return io.NopCloser(r), b
}

42
internal/gtserror/util.go Normal file
View File

@ -0,0 +1,42 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtserror
import (
"io"
"codeberg.org/gruf/go-byteutil"
)
// drainBody will produce a truncated output of the content
// of given io.ReadCloser body, useful for logs / errors.
func drainBody(body io.ReadCloser, trunc int) string {
// Limit response to 'trunc' bytes.
buf := make([]byte, trunc)
// Read body into err buffer.
n, _ := io.ReadFull(body, buf)
if n == 0 {
// No error body, return
// reasonable error str.
return "<empty>"
}
return byteutil.B2S(buf[:n])
}

View File

@ -41,6 +41,9 @@ import (
) )
var ( var (
// ErrInvalidRequest is returned if a given HTTP request is invalid and cannot be performed.
ErrInvalidRequest = errors.New("invalid http request")
// ErrInvalidNetwork is returned if the request would not be performed over TCP // ErrInvalidNetwork is returned if the request would not be performed over TCP
ErrInvalidNetwork = errors.New("invalid network type") ErrInvalidNetwork = errors.New("invalid network type")
@ -90,6 +93,9 @@ type Config struct {
// cases to protect against forged / unknown content-lengths // cases to protect against forged / unknown content-lengths
// - protection from server side request forgery (SSRF) by only dialing // - protection from server side request forgery (SSRF) by only dialing
// out to known public IP prefixes, configurable with allows/blocks // out to known public IP prefixes, configurable with allows/blocks
// - retry-backoff logic for error temporary HTTP error responses
// - optional request signing
// - request logging
type Client struct { type Client struct {
client http.Client client http.Client
badHosts cache.Cache[string, struct{}] badHosts cache.Cache[string, struct{}]
@ -156,14 +162,14 @@ func New(cfg Config) *Client {
return &c return &c
} }
// Do ... // Do will essentially perform http.Client{}.Do() with retry-backoff functionality.
func (c *Client) Do(r *http.Request) (*http.Response, error) { func (c *Client) Do(r *http.Request) (*http.Response, error) {
return c.DoSigned(r, func(r *http.Request) error { return c.DoSigned(r, func(r *http.Request) error {
return nil // no request signing return nil // no request signing
}) })
} }
// DoSigned ... // DoSigned will essentially perform http.Client{}.Do() with retry-backoff functionality and requesting signing..
func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, err error) { func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, err error) {
const ( const (
// max no. attempts. // max no. attempts.
@ -173,6 +179,11 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e
baseBackoff = 2 * time.Second baseBackoff = 2 * time.Second
) )
// First validate incoming request.
if err := ValidateRequest(r); err != nil {
return nil, err
}
// Get request hostname. // Get request hostname.
host := r.URL.Hostname() host := r.URL.Hostname()
@ -234,8 +245,8 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e
return rsp, nil return rsp, nil
} }
// Generate error from status code for logging // Create loggable error from response status code.
err = errors.New(`http response "` + rsp.Status + `"`) err = fmt.Errorf(`http response: %s`, rsp.Status)
// Search for a provided "Retry-After" header value. // Search for a provided "Retry-After" header value.
if after := rsp.Header.Get("Retry-After"); after != "" { if after := rsp.Header.Get("Retry-After"); after != "" {
@ -307,7 +318,7 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e
return return
} }
// do ... // do wraps http.Client{}.Do() to provide safely limited response bodies.
func (c *Client) do(req *http.Request) (*http.Response, error) { func (c *Client) do(req *http.Request) (*http.Response, error) {
// Perform the HTTP request. // Perform the HTTP request.
rsp, err := c.client.Do(req) rsp, err := c.client.Do(req)

View File

@ -0,0 +1,62 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package httpclient
import (
"fmt"
"net/http"
"strings"
"golang.org/x/net/http/httpguts"
)
// ValidateRequest performs the same request validation logic found in the default
// net/http.Transport{}.roundTrip() function, but pulls it out into this separate
// function allowing validation errors to be wrapped under a single error type.
func ValidateRequest(r *http.Request) error {
switch {
case r.URL == nil:
return fmt.Errorf("%w: nil url", ErrInvalidRequest)
case r.Header == nil:
return fmt.Errorf("%w: nil header", ErrInvalidRequest)
case r.URL.Host == "":
return fmt.Errorf("%w: empty url host", ErrInvalidRequest)
case r.URL.Scheme != "http" && r.URL.Scheme != "https":
return fmt.Errorf("%w: unsupported protocol %q", ErrInvalidRequest, r.URL.Scheme)
case strings.IndexFunc(r.Method, func(r rune) bool { return !httpguts.IsTokenRune(r) }) != -1:
return fmt.Errorf("%w: invalid method %q", ErrInvalidRequest, r.Method)
}
for key, values := range r.Header {
// Check field key name is valid
if !httpguts.ValidHeaderFieldName(key) {
return fmt.Errorf("%w: invalid header field name %q", ErrInvalidRequest, key)
}
// Check each field value is valid
for i := 0; i < len(values); i++ {
if !httpguts.ValidHeaderFieldValue(values[i]) {
return fmt.Errorf("%w: invalid header field value %q", ErrInvalidRequest, values[i])
}
}
}
// ps. kim wrote this
return nil
}

View File

@ -19,7 +19,6 @@ package transport
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"sync" "sync"
@ -131,8 +130,7 @@ func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error {
if code := rsp.StatusCode; code != http.StatusOK && if code := rsp.StatusCode; code != http.StatusOK &&
code != http.StatusCreated && code != http.StatusAccepted { code != http.StatusCreated && code != http.StatusAccepted {
err := fmt.Errorf("POST request to %s failed: %s", url, rsp.Status) return gtserror.NewResponseError(rsp)
return gtserror.WithStatusCode(err, rsp.StatusCode)
} }
return nil return nil

View File

@ -19,7 +19,6 @@ package transport
import ( import (
"context" "context"
"fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -66,8 +65,7 @@ func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, erro
defer rsp.Body.Close() defer rsp.Body.Close()
if rsp.StatusCode != http.StatusOK { if rsp.StatusCode != http.StatusOK {
err := fmt.Errorf("GET request to %s failed: %s", iriStr, rsp.Status) return nil, gtserror.NewResponseError(rsp)
return nil, gtserror.WithStatusCode(err, rsp.StatusCode)
} }
return io.ReadAll(rsp.Body) return io.ReadAll(rsp.Body)

View File

@ -102,8 +102,7 @@ func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL)
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) return nil, gtserror.NewResponseError(resp)
return nil, gtserror.WithStatusCode(err, resp.StatusCode)
} }
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
@ -133,7 +132,7 @@ func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL)
ID: ulid, ID: ulid,
Domain: iri.Host, Domain: iri.Host,
Title: apiResp.Title, Title: apiResp.Title,
URI: fmt.Sprintf("%s://%s", iri.Scheme, iri.Host), URI: iri.Scheme + "://" + iri.Host,
ShortDescription: apiResp.ShortDescription, ShortDescription: apiResp.ShortDescription,
Description: apiResp.Description, Description: apiResp.Description,
ContactEmail: apiResp.Email, ContactEmail: apiResp.Email,
@ -253,8 +252,7 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) return nil, gtserror.NewResponseError(resp)
return nil, gtserror.WithStatusCode(err, resp.StatusCode)
} }
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
@ -305,8 +303,7 @@ func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.No
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) return nil, gtserror.NewResponseError(resp)
return nil, gtserror.WithStatusCode(err, resp.StatusCode)
} }
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)

View File

@ -19,7 +19,6 @@ package transport
import ( import (
"context" "context"
"fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -47,8 +46,7 @@ func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.Read
// Check for an expected status code // Check for an expected status code
if rsp.StatusCode != http.StatusOK { if rsp.StatusCode != http.StatusOK {
err := fmt.Errorf("GET request to %s failed: %s", iriStr, rsp.Status) return nil, 0, gtserror.NewResponseError(rsp)
return nil, 0, gtserror.WithStatusCode(err, rsp.StatusCode)
} }
return rsp.Body, rsp.ContentLength, nil return rsp.Body, rsp.ContentLength, nil

View File

@ -27,6 +27,7 @@ import (
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
) )
// webfingerURLFor returns the URL to try a webfinger request against, as // webfingerURLFor returns the URL to try a webfinger request against, as
@ -105,14 +106,16 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
// From here on out, we're handling different failure scenarios and // From here on out, we're handling different failure scenarios and
// deciding whether we should do a host-meta based fallback or not // deciding whether we should do a host-meta based fallback or not
if (rsp.StatusCode >= 500 && rsp.StatusCode < 600) || cached { // Response status codes >= 500 are returned as errors by the wrapped HTTP client.
// In case we got a 5xx, bail out irrespective of if the value //
// was cached or not. The target may be broken or be signalling // if (rsp.StatusCode >= 500 && rsp.StatusCode < 600) || cached {
// us to back-off. // In case we got a 5xx, bail out irrespective of if the value
// // was cached or not. The target may be broken or be signalling
// If it's any error but the URL was cached, bail out too // us to back-off.
return nil, fmt.Errorf("GET request to %s failed: %s", req.URL.String(), rsp.Status) //
} // If it's any error but the URL was cached, bail out too
// return nil, gtserror.NewResponseError(rsp)
// }
// So far we've failed to get a successful response from the expected // So far we've failed to get a successful response from the expected
// webfinger endpoint. Lets try and discover the webfinger endpoint // webfinger endpoint. Lets try and discover the webfinger endpoint
@ -153,7 +156,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
} }
// We've reached the end of the line here, both the original request // We've reached the end of the line here, both the original request
// and our attempt to resolve it through the fallback have failed // and our attempt to resolve it through the fallback have failed
return nil, fmt.Errorf("GET request to %s failed: %s", req.URL.String(), rsp.Status) return nil, gtserror.NewResponseError(rsp)
} }
// Set the URL in cache here, since host-meta told us this should be the // Set the URL in cache here, since host-meta told us this should be the

View File

@ -209,6 +209,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
reader := bytes.NewReader(responseBytes) reader := bytes.NewReader(responseBytes)
readCloser := io.NopCloser(reader) readCloser := io.NopCloser(reader)
return &http.Response{ return &http.Response{
Request: req,
StatusCode: responseCode, StatusCode: responseCode,
Body: readCloser, Body: readCloser,
ContentLength: int64(responseContentLength), ContentLength: int64(responseContentLength),