mirror of
1
Fork 0

[bugfix] return early in websocket upgrade handler (#1315)

* launch websocket streaming in goroutine to allow upgrade handler to return

* don't send any message on ping, improved close check on failed read

* use context to signal wsconn close, ensure canceled in read goroutine

Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
kim 2023-01-08 11:43:08 +00:00 committed by GitHub
parent 98edd75f1b
commit 1bda6a2002
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 110 additions and 70 deletions

View File

@ -19,6 +19,8 @@
package api package api
import ( import (
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin" "github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
@ -122,7 +124,7 @@ func NewClient(db db.DB, p processing.Processor) *Client {
notifications: notifications.New(p), notifications: notifications.New(p),
search: search.New(p), search: search.New(p),
statuses: statuses.New(p), statuses: statuses.New(p),
streaming: streaming.New(p), streaming: streaming.New(p, time.Second*30, 4096),
timelines: timelines.New(p), timelines: timelines.New(p),
user: user.New(p), user: user.New(p),
} }

View File

@ -19,8 +19,9 @@
package streaming package streaming
import ( import (
"context"
"errors"
"fmt" "fmt"
"net/http"
"time" "time"
"codeberg.org/gruf/go-kv" "codeberg.org/gruf/go-kv"
@ -32,16 +33,6 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
var (
wsUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// we expect cors requests (via eg., pinafore.social) so be lenient
CheckOrigin: func(r *http.Request) bool { return true },
}
errNoToken = fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader)
)
// StreamGETHandler swagger:operation GET /api/v1/streaming streamGet // StreamGETHandler swagger:operation GET /api/v1/streaming streamGet
// //
// Initiate a websocket connection for live streaming of statuses and notifications. // Initiate a websocket connection for live streaming of statuses and notifications.
@ -150,21 +141,20 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
return return
} }
var accessToken string var token string
if t := c.Query(AccessTokenQueryKey); t != "" {
// try query param first // First we check for a query param provided access token
accessToken = t if token = c.Query(AccessTokenQueryKey); token == "" {
} else if t := c.GetHeader(AccessTokenHeader); t != "" { // Else we check the HTTP header provided token
// fall back to Sec-Websocket-Protocol if token = c.GetHeader(AccessTokenHeader); token == "" {
accessToken = t const errStr = "no access token provided"
} else { err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr)
// no token apiutil.ErrorHandler(c, err, m.processor.InstanceGet)
err := errNoToken
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }
}
account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken) account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), token)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return return
@ -178,51 +168,97 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
l := log.WithFields(kv.Fields{ l := log.WithFields(kv.Fields{
{"account", account.Username}, {"account", account.Username},
{"path", BasePath},
{"streamID", stream.ID}, {"streamID", stream.ID},
{"streamType", streamType}, {"streamType", streamType},
}...) }...)
wsConn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) // Upgrade the incoming HTTP request, which hijacks the underlying
// connection and reuses it for the websocket (non-http) protocol.
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
// If the upgrade fails, then Upgrade replies to the client with an HTTP error response. l.Errorf("error upgrading websocket connection: %v", err)
// Because websocket issues are a pretty common source of headaches, we should also log
// this at Error to make this plenty visible and help admins out a bit.
l.Errorf("error upgrading websocket connection: %s", err)
close(stream.Hangup) close(stream.Hangup)
return return
} }
go func() {
// We perform the main websocket send loop in a separate
// goroutine in order to let the upgrade handler return.
// This prevents the upgrade handler from holding open any
// throttle / rate-limit request tokens which could become
// problematic on instances with multiple users.
l.Info("opened websocket connection")
defer l.Info("closed websocket connection")
// Create new context for lifetime of the connection
ctx, cncl := context.WithCancel(context.Background())
// Create ticker to send alive pings
pinger := time.NewTicker(m.dTicker)
defer func() { defer func() {
// cleanup // Signal done
wsConn.Close() cncl()
// Close websocket conn
_ = wsConn.Close()
// Close processor stream
close(stream.Hangup) close(stream.Hangup)
// Stop ping ticker
pinger.Stop()
}() }()
streamTicker := time.NewTicker(m.tickDuration) go func() {
defer streamTicker.Stop() // Signal done
defer cncl()
// We want to stay in the loop as long as possible while the client is connected. for {
// The only thing that should break the loop is if the client leaves or the connection becomes unhealthy. // We have to listen for received websocket messages in
// order to trigger the underlying wsConn.PingHandler().
// //
// If the loop does break, we expect the client to reattempt connection, so it's cheap to leave + try again // So we wait on received messages but only act on errors.
wsLoop: _, _, err := wsConn.ReadMessage()
if err != nil {
if ctx.Err() == nil {
// Only log error if the connection was not closed
// by us. Uncanceled context indicates this is the case.
l.Errorf("error reading from websocket: %v", err)
}
return
}
}
}()
for { for {
select { select {
case m := <-stream.Messages: // Connection closed
l.Trace("received message from stream") case <-ctx.Done():
if err := wsConn.WriteJSON(m); err != nil { return
l.Debugf("error writing json to websocket connection; breaking off: %s", err)
break wsLoop // Received next stream message
case msg := <-stream.Messages:
l.Tracef("sending message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Errorf("error writing json to websocket: %v", err)
return
} }
l.Trace("wrote message into websocket connection")
case <-streamTicker.C: // Reset on each successful send.
l.Trace("received TICK from ticker") pinger.Reset(m.dTicker)
if err := wsConn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil {
l.Debugf("error writing ping to websocket connection; breaking off: %s", err) // Send keep-alive "ping"
break wsLoop case <-pinger.C:
} l.Trace("pinging websocket ...")
l.Trace("wrote ping message into websocket connection") if err := wsConn.WriteMessage(
websocket.PingMessage,
[]byte{},
); err != nil {
l.Errorf("error writing ping to websocket: %v", err)
return
} }
} }
} }
}()
}

View File

@ -23,6 +23,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
) )
@ -42,20 +43,21 @@ const (
type Module struct { type Module struct {
processor processing.Processor processor processing.Processor
tickDuration time.Duration dTicker time.Duration
wsUpgrade websocket.Upgrader
} }
func New(processor processing.Processor) *Module { func New(processor processing.Processor, dTicker time.Duration, wsBuf int) *Module {
return &Module{ return &Module{
processor: processor, processor: processor,
tickDuration: 30 * time.Second, dTicker: dTicker,
} wsUpgrade: websocket.Upgrader{
} ReadBufferSize: wsBuf, // we don't expect reads
WriteBufferSize: wsBuf,
func NewWithTickDuration(processor processing.Processor, tickDuration time.Duration) *Module { // we expect cors requests (via eg., pinafore.social) so be lenient
return &Module{ CheckOrigin: func(r *http.Request) bool { return true },
processor: processor, },
tickDuration: tickDuration,
} }
} }

View File

@ -99,7 +99,7 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.streamingModule = streaming.NewWithTickDuration(suite.processor, 1) suite.streamingModule = streaming.New(suite.processor, 1, 4096)
suite.NoError(suite.processor.Start()) suite.NoError(suite.processor.Start())
} }