mirror of
1
Fork 0

[bugfix] fix possible mutex lockup during streaming code (#2633)

* rewrite Stream{} to use much less mutex locking, update related code

* use new context for the stream context

* ensure stream gets closed on return of writeTo / readFrom WSConn()

* ensure stream write timeout gets cancelled

* remove embedded context type from Stream{}, reformat log messages for consistency

* use c.Request.Context() for context passed into Stream().Open()

* only return 1 boolean, fix tests to expect multiple stream types in messages

* changes to ping logic

* further improved ping logic

* don't export unused function types, update message sending to only include relevant stream type

* ensure stream gets closed 🤦

* update to error log on failed json marshal (instead of panic)

* inverse websocket read error checking to _ignore_ expected close errors
This commit is contained in:
kim 2024-02-20 18:07:49 +00:00 committed by GitHub
parent 8cafa6b74b
commit 291e180990
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 535 additions and 451 deletions

View File

@ -22,10 +22,10 @@ import (
"slices" "slices"
"time" "time"
"codeberg.org/gruf/go-kv"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
streampkg "github.com/superseriousbusiness/gotosocial/internal/stream" streampkg "github.com/superseriousbusiness/gotosocial/internal/stream"
@ -202,7 +202,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// functions pass messages into a channel, which we can // functions pass messages into a channel, which we can
// then read from and put into a websockets connection. // then read from and put into a websockets connection.
stream, errWithCode := m.processor.Stream().Open( stream, errWithCode := m.processor.Stream().Open(
c.Request.Context(), c.Request.Context(), // this ctx is only used for logging
account, account,
streamType, streamType,
) )
@ -213,10 +213,8 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
l := log. l := log.
WithContext(c.Request.Context()). WithContext(c.Request.Context()).
WithFields(kv.Fields{ WithField("streamID", id.NewULID()).
{"username", account.Username}, WithField("username", account.Username)
{"streamID", stream.ID},
}...)
// Upgrade the incoming HTTP request. This hijacks the // Upgrade the incoming HTTP request. This hijacks the
// underlying connection and reuses it for the websocket // underlying connection and reuses it for the websocket
@ -227,18 +225,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
l.Errorf("error upgrading websocket connection: %v", err) l.Errorf("error upgrading websocket connection: %v", err)
close(stream.Hangup) stream.Close()
return return
} }
l.Info("opened websocket connection")
// We perform the main websocket rw loops in a separate // We perform the main websocket rw loops in a separate
// goroutine in order to let the upgrade handler return. // goroutine in order to let the upgrade handler return.
// This prevents the upgrade handler from holding open any // This prevents the upgrade handler from holding open any
// throttle / rate-limit request tokens which could become // throttle / rate-limit request tokens which could become
// problematic on instances with multiple users. // problematic on instances with multiple users.
go m.handleWSConn(account.Username, wsConn, stream) go m.handleWSConn(&l, wsConn, stream)
} }
// handleWSConn handles a two-way websocket streaming connection. // handleWSConn handles a two-way websocket streaming connection.
@ -246,48 +242,39 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// into the connection. If any errors are encountered while reading // into the connection. If any errors are encountered while reading
// or writing (including expected errors like clients leaving), the // or writing (including expected errors like clients leaving), the
// connection will be closed. // connection will be closed.
func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) { func (m *Module) handleWSConn(l *log.Entry, wsConn *websocket.Conn, stream *streampkg.Stream) {
// Create new context for the lifetime of this connection. l.Info("opened websocket connection")
ctx, cancel := context.WithCancel(context.Background())
l := log. // Create new async context with cancel.
WithContext(ctx). ctx, cncl := context.WithCancel(context.Background())
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
// Create ticker to send keepalive pings
pinger := time.NewTicker(m.dTicker)
// Read messages coming from the Websocket client connection into the server.
go func() { go func() {
defer cancel() defer cncl()
m.readFromWSConn(ctx, username, wsConn, stream)
// Read messages from websocket to server.
m.readFromWSConn(ctx, wsConn, stream, l)
}() }()
// Write messages coming from the processor into the Websocket client connection.
go func() { go func() {
defer cancel() defer cncl()
m.writeToWSConn(ctx, username, wsConn, stream, pinger)
// Write messages from processor in websocket conn.
m.writeToWSConn(ctx, wsConn, stream, m.dTicker, l)
}() }()
// Wait for either the read or write functions to close, to indicate // Wait for ctx
// that the client has left, or something else has gone wrong. // to be closed.
<-ctx.Done() <-ctx.Done()
// Close stream
// straightaway.
stream.Close()
// Tidy up underlying websocket connection. // Tidy up underlying websocket connection.
if err := wsConn.Close(); err != nil { if err := wsConn.Close(); err != nil {
l.Errorf("error closing websocket connection: %v", err) l.Errorf("error closing websocket connection: %v", err)
} }
// Close processor channel so the processor knows
// not to send any more messages to this stream.
close(stream.Hangup)
// Stop ping ticker (tiny resource saving).
pinger.Stop()
l.Info("closed websocket connection") l.Info("closed websocket connection")
} }
@ -299,89 +286,64 @@ func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *s
// if the given context is canceled. // if the given context is canceled.
func (m *Module) readFromWSConn( func (m *Module) readFromWSConn(
ctx context.Context, ctx context.Context,
username string,
wsConn *websocket.Conn, wsConn *websocket.Conn,
stream *streampkg.Stream, stream *streampkg.Stream,
l *log.Entry,
) { ) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
readLoop:
for { for {
select { var msg struct {
case <-ctx.Done(): Type string `json:"type"`
// Connection closed. Stream string `json:"stream"`
break readLoop List string `json:"list,omitempty"`
}
default:
// Read JSON objects from the client and act on them. // Read JSON objects from the client and act on them.
var msg map[string]string
if err := wsConn.ReadJSON(&msg); err != nil { if err := wsConn.ReadJSON(&msg); err != nil {
// Only log an error if something weird happened. // Only log an error if something weird happened.
// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7 // See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
if websocket.IsUnexpectedCloseError(err, []int{ if !websocket.IsCloseError(err, []int{
websocket.CloseNormalClosure, websocket.CloseNormalClosure,
websocket.CloseGoingAway, websocket.CloseGoingAway,
websocket.CloseNoStatusReceived, websocket.CloseNoStatusReceived,
}...) { }...) {
l.Errorf("error reading from websocket: %v", err) l.Errorf("error during websocket read: %v", err)
} }
// The connection is gone; no // The connection is gone; no
// further streaming possible. // further streaming possible.
break readLoop break
} }
// Messages *from* the WS connection are infrequent // Messages *from* the WS connection are infrequent
// and usually interesting, so log this at info. // and usually interesting, so log this at info.
l.Infof("received message from websocket: %v", msg) l.Infof("received websocket message: %+v", msg)
// If the message contains 'stream' and 'type' fields, we can
// update the set of timelines that are subscribed for events.
updateType, ok := msg["type"]
if !ok {
l.Warn("'type' field not provided")
continue
}
updateStream, ok := msg["stream"]
if !ok {
l.Warn("'stream' field not provided")
continue
}
// Ignore if the updateStreamType is unknown (or missing), // Ignore if the updateStreamType is unknown (or missing),
// so a bad client can't cause extra memory allocations // so a bad client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { if !slices.Contains(streampkg.AllStatusTimelines, msg.Stream) {
l.Warnf("unknown 'stream' field: %v", msg) l.Warnf("unknown 'stream' field: %v", msg)
continue continue
} }
updateList, ok := msg["list"] if msg.List != "" {
if ok { // If a list is given, add this to
updateStream += ":" + updateList // the stream name as this is how we
// we track stream types internally.
msg.Stream += ":" + msg.List
} }
switch updateType { switch msg.Type {
case "subscribe": case "subscribe":
stream.Lock() stream.Subscribe(msg.Stream)
stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe": case "unsubscribe":
stream.Lock() stream.Unsubscribe(msg.Stream)
delete(stream.StreamTypes, updateStream)
stream.Unlock()
default: default:
l.Warnf("invalid 'type' field: %v", msg) l.Warnf("invalid 'type' field: %v", msg)
} }
} }
}
l.Debug("finished reading from websocket connection") l.Debug("finished websocket read")
} }
// writeToWSConn receives messages coming from the processor via the // writeToWSConn receives messages coming from the processor via the
@ -393,46 +355,47 @@ readLoop:
// if the given context is canceled. // if the given context is canceled.
func (m *Module) writeToWSConn( func (m *Module) writeToWSConn(
ctx context.Context, ctx context.Context,
username string,
wsConn *websocket.Conn, wsConn *websocket.Conn,
stream *streampkg.Stream, stream *streampkg.Stream,
pinger *time.Ticker, ping time.Duration,
l *log.Entry,
) { ) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
writeLoop:
for { for {
select { // Wrap context with timeout to send a ping.
case <-ctx.Done(): pingctx, cncl := context.WithTimeout(ctx, ping)
// Connection closed.
break writeLoop
case msg := <-stream.Messages: // Block on receipt of msg.
// Received a new message from the processor. msg, ok := stream.Recv(pingctx)
l.Tracef("writing message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Debugf("error writing json to websocket: %v", err)
break writeLoop
}
// Reset pinger on successful send, since // Check if cancel because ping.
// we know the connection is still there. pinged := (pingctx.Err() != nil)
pinger.Reset(m.dTicker) cncl()
case <-pinger.C: switch {
// Time to send a keep-alive "ping". case !ok && pinged:
l.Trace("writing ping control message to websocket") // The ping context timed out!
l.Trace("writing websocket ping")
// Wrapped context time-out, send a keep-alive "ping".
if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil { if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
l.Debugf("error writing ping to websocket: %v", err) l.Debugf("error writing websocket ping: %v", err)
break writeLoop break
} }
case !ok:
// Stream was
// closed.
return
}
l.Trace("writing websocket message: %+v", msg)
// Received a new message from the processor.
if err := wsConn.WriteJSON(msg); err != nil {
l.Debugf("error writing websocket message: %v", err)
break
} }
} }
l.Debug("finished writing to websocket connection") l.Debug("finished websocket write")
} }

View File

@ -18,38 +18,16 @@
package stream package stream
import ( import (
"fmt" "context"
"strings"
"github.com/superseriousbusiness/gotosocial/internal/stream" "github.com/superseriousbusiness/gotosocial/internal/stream"
) )
// Delete streams the delete of the given statusID to *ALL* open streams. // Delete streams the delete of the given statusID to *ALL* open streams.
func (p *Processor) Delete(statusID string) error { func (p *Processor) Delete(ctx context.Context, statusID string) {
errs := []string{} p.streams.PostAll(ctx, stream.Message{
Payload: statusID,
// get all account IDs with open streams Event: stream.EventTypeDelete,
accountIDs := []string{} Stream: stream.AllStatusTimelines,
p.streamMap.Range(func(k interface{}, _ interface{}) bool {
key, ok := k.(string)
if !ok {
panic("streamMap key was not a string (account id)")
}
accountIDs = append(accountIDs, key)
return true
}) })
// stream the delete to every account
for _, accountID := range accountIDs {
if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) != 0 {
return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";"))
}
return nil
} }

View File

@ -18,20 +18,29 @@
package stream package stream
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt"
"codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream" "github.com/superseriousbusiness/gotosocial/internal/stream"
) )
// Notify streams the given notification to any open, appropriate streams belonging to the given account. // Notify streams the given notification to any open, appropriate streams belonging to the given account.
func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error { func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) {
bytes, err := json.Marshal(n) b, err := json.Marshal(notif)
if err != nil { if err != nil {
return fmt.Errorf("error marshalling notification to json: %s", err) log.Errorf(ctx, "error marshaling json: %v", err)
return
} }
p.streams.Post(ctx, account.ID, stream.Message{
return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID) Payload: byteutil.B2S(b),
Event: stream.EventTypeNotification,
Stream: []string{
stream.TimelineNotifications,
stream.TimelineHome,
},
})
} }

View File

@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() {
Account: followAccountAPIModel, Account: followAccountAPIModel,
} }
err = suite.streamProcessor.Notify(notification, account) suite.streamProcessor.Notify(context.Background(), account, notification)
suite.NoError(err)
msg, ok := openStream.Recv(context.Background())
suite.True(ok)
msg := <-openStream.Messages
dst := new(bytes.Buffer) dst := new(bytes.Buffer)
err = json.Indent(dst, []byte(msg.Payload), "", " ") err = json.Indent(dst, []byte(msg.Payload), "", " ")
suite.NoError(err) suite.NoError(err)

View File

@ -19,13 +19,10 @@ package stream
import ( import (
"context" "context"
"errors"
"fmt"
"codeberg.org/gruf/go-kv" "codeberg.org/gruf/go-kv"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream" "github.com/superseriousbusiness/gotosocial/internal/stream"
) )
@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT
{"streamType", streamType}, {"streamType", streamType},
}...) }...)
l.Debug("received open stream request") l.Debug("received open stream request")
return p.streams.Open(account.ID, streamType), nil
var (
streamID string
err error
)
// Each stream needs a unique ID so we know to close it.
streamID, err = id.NewRandomULID()
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err))
}
// Each stream can be subscibed to multiple types.
// Record them in a set, and include the initial one
// if it was given to us.
streamTypes := map[string]any{}
if streamType != "" {
streamTypes[streamType] = true
}
newStream := &stream.Stream{
ID: streamID,
StreamTypes: streamTypes,
Messages: make(chan *stream.Message, 100),
Hangup: make(chan interface{}, 1),
Connected: true,
}
go p.waitToCloseStream(account, newStream)
v, ok := p.streamMap.Load(account.ID)
if ok {
// There is an entry in the streamMap
// for this account. Parse it out.
streamsForAccount, ok := v.(*stream.StreamsForAccount)
if !ok {
return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
}
// Append new stream to existing entry.
streamsForAccount.Lock()
streamsForAccount.Streams = append(streamsForAccount.Streams, newStream)
streamsForAccount.Unlock()
} else {
// There is no entry in the streamMap for
// this account yet. Create one and store it.
p.streamMap.Store(account.ID, &stream.StreamsForAccount{
Streams: []*stream.Stream{
newStream,
},
})
}
return newStream, nil
}
// waitToCloseStream waits until the hangup channel is closed for the given stream.
// It then iterates through the map of streams stored by the processor, removes the stream from it,
// and then closes the messages channel of the stream to indicate that the channel should no longer be read from.
func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) {
<-thisStream.Hangup // wait for a hangup message
// lock the stream to prevent more messages being put in it while we work
thisStream.Lock()
defer thisStream.Unlock()
// indicate the stream is no longer connected
thisStream.Connected = false
// load and parse the entry for this account from the stream map
v, ok := p.streamMap.Load(account.ID)
if !ok || v == nil {
return
}
streamsForAccount, ok := v.(*stream.StreamsForAccount)
if !ok {
return
}
// lock the streams for account while we remove this stream from its slice
streamsForAccount.Lock()
defer streamsForAccount.Unlock()
// put everything into modified streams *except* the stream we're removing
modifiedStreams := []*stream.Stream{}
for _, s := range streamsForAccount.Streams {
if s.ID != thisStream.ID {
modifiedStreams = append(modifiedStreams, s)
}
}
streamsForAccount.Streams = modifiedStreams
// finally close the messages channel so no more messages can be read from it
close(thisStream.Messages)
} }

View File

@ -18,21 +18,26 @@
package stream package stream
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt"
"codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream" "github.com/superseriousbusiness/gotosocial/internal/stream"
) )
// StatusUpdate streams the given edited status to any open, appropriate // StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account.
// streams belonging to the given account. func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { b, err := json.Marshal(status)
bytes, err := json.Marshal(s)
if err != nil { if err != nil {
return fmt.Errorf("error marshalling status to json: %s", err) log.Errorf(ctx, "error marshaling json: %v", err)
return
} }
p.streams.Post(ctx, account.ID, stream.Message{
return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID) Payload: byteutil.B2S(b),
Event: stream.EventTypeStatusUpdate,
Stream: []string{streamType},
})
} }

View File

@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() {
apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account) apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account)
suite.NoError(err) suite.NoError(err)
err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome}) suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome)
suite.NoError(err)
msg, ok := openStream.Recv(context.Background())
suite.True(ok)
msg := <-openStream.Messages
dst := new(bytes.Buffer) dst := new(bytes.Buffer)
err = json.Indent(dst, []byte(msg.Payload), "", " ") err = json.Indent(dst, []byte(msg.Payload), "", " ")
suite.NoError(err) suite.NoError(err)

View File

@ -18,8 +18,6 @@
package stream package stream
import ( import (
"sync"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/stream" "github.com/superseriousbusiness/gotosocial/internal/stream"
@ -28,53 +26,13 @@ import (
type Processor struct { type Processor struct {
state *state.State state *state.State
oauthServer oauth.Server oauthServer oauth.Server
streamMap *sync.Map streams stream.Streams
} }
func New(state *state.State, oauthServer oauth.Server) Processor { func New(state *state.State, oauthServer oauth.Server) Processor {
return Processor{ return Processor{
state: state, state: state,
oauthServer: oauthServer, oauthServer: oauthServer,
streamMap: &sync.Map{}, streams: stream.Streams{},
} }
} }
// toAccount streams the given payload with the given event type to any streams currently open for the given account ID.
func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error {
// Load all streams open for this account.
v, ok := p.streamMap.Load(accountID)
if !ok {
return nil // No entry = nothing to stream.
}
streamsForAccount := v.(*stream.StreamsForAccount)
streamsForAccount.Lock()
defer streamsForAccount.Unlock()
for _, s := range streamsForAccount.Streams {
s.Lock()
defer s.Unlock()
if !s.Connected {
continue
}
typeLoop:
for _, streamType := range streamTypes {
if _, found := s.StreamTypes[streamType]; found {
s.Messages <- &stream.Message{
Stream: []string{streamType},
Event: string(event),
Payload: payload,
}
// Break out to the outer loop,
// to avoid sending duplicates of
// the same event to the same stream.
break typeLoop
}
}
}
return nil
}

View File

@ -18,20 +18,26 @@
package stream package stream
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt"
"codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream" "github.com/superseriousbusiness/gotosocial/internal/stream"
) )
// Update streams the given update to any open, appropriate streams belonging to the given account. // Update streams the given update to any open, appropriate streams belonging to the given account.
func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
bytes, err := json.Marshal(s) b, err := json.Marshal(status)
if err != nil { if err != nil {
return fmt.Errorf("error marshalling status to json: %s", err) log.Errorf(ctx, "error marshaling json: %v", err)
return
} }
p.streams.Post(ctx, account.ID, stream.Message{
return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID) Payload: byteutil.B2S(b),
Event: stream.EventTypeUpdate,
Stream: []string{streamType},
})
} }

View File

@ -116,23 +116,20 @@ func (suite *FromClientAPITestSuite) checkStreamed(
expectPayload string, expectPayload string,
expectEventType string, expectEventType string,
) { ) {
var msg *stream.Message
streamLoop: // Set a 5s timeout on context.
for { ctx := context.Background()
select { ctx, cncl := context.WithTimeout(ctx, time.Second*5)
case msg = <-str.Messages: defer cncl()
break streamLoop // Got it.
case <-time.After(5 * time.Second): msg, ok := str.Recv(ctx)
break streamLoop // Didn't get it.
} if expectMessage && !ok {
suite.FailNow("expected a message but message was not received")
} }
if expectMessage && msg == nil { if !expectMessage && ok {
suite.FailNow("expected a message but message was nil") suite.FailNow("expected no message but message was received")
}
if !expectMessage && msg != nil {
suite.FailNow("expected no message but message was not nil")
} }
if expectPayload != "" && msg.Payload != expectPayload { if expectPayload != "" && msg.Payload != expectPayload {

View File

@ -130,14 +130,9 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() {
suite.Equal(replyingStatus.ID, notif.StatusID) suite.Equal(replyingStatus.ID, notif.StatusID)
suite.False(*notif.Read) suite.False(*notif.Read)
// the notification should be streamed ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
var msg *stream.Message msg, ok := wssStream.Recv(ctx)
select { suite.True(ok)
case msg = <-wssStream.Messages:
// fine
case <-time.After(5 * time.Second):
suite.FailNow("no message from wssStream")
}
suite.Equal(stream.EventTypeNotification, msg.Event) suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload) suite.NotEmpty(msg.Payload)
@ -203,14 +198,10 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {
suite.Equal(fave.StatusID, notif.StatusID) suite.Equal(fave.StatusID, notif.StatusID)
suite.False(*notif.Read) suite.False(*notif.Read)
// 2. a notification should be streamed ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
var msg *stream.Message msg, ok := wssStream.Recv(ctx)
select { suite.True(ok)
case msg = <-wssStream.Messages:
// fine
case <-time.After(5 * time.Second):
suite.FailNow("no message from wssStream")
}
suite.Equal(stream.EventTypeNotification, msg.Event) suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload) suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream) suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream)
@ -277,7 +268,9 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(
suite.False(*notif.Read) suite.False(*notif.Read)
// 2. no notification should be streamed to the account that received the fave message, because they weren't the target // 2. no notification should be streamed to the account that received the fave message, because they weren't the target
suite.Empty(wssStream.Messages) ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
_, ok := wssStream.Recv(ctx)
suite.False(ok)
} }
func (suite *FromFediAPITestSuite) TestProcessAccountDelete() { func (suite *FromFediAPITestSuite) TestProcessAccountDelete() {
@ -405,14 +398,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
}) })
suite.NoError(err) suite.NoError(err)
// a notification should be streamed ctx, _ = context.WithTimeout(ctx, time.Second*5)
var msg *stream.Message msg, ok := wssStream.Recv(context.Background())
select { suite.True(ok)
case msg = <-wssStream.Messages:
// fine
case <-time.After(5 * time.Second):
suite.FailNow("no message from wssStream")
}
suite.Equal(stream.EventTypeNotification, msg.Event) suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload) suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
@ -423,7 +412,7 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
suite.Equal(originAccount.ID, notif.Account.ID) suite.Equal(originAccount.ID, notif.Account.ID)
// no messages should have been sent out, since we didn't need to federate an accept // no messages should have been sent out, since we didn't need to federate an accept
suite.Empty(suite.httpClient.SentMessages) suite.Empty(&suite.httpClient.SentMessages)
} }
func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() { func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
@ -503,14 +492,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
suite.Equal(originAccount.URI, accept.To) suite.Equal(originAccount.URI, accept.To)
suite.Equal("Accept", accept.Type) suite.Equal("Accept", accept.Type)
// a notification should be streamed ctx, _ = context.WithTimeout(ctx, time.Second*5)
var msg *stream.Message msg, ok := wssStream.Recv(context.Background())
select { suite.True(ok)
case msg = <-wssStream.Messages:
// fine
case <-time.After(5 * time.Second):
suite.FailNow("no message from wssStream")
}
suite.Equal(stream.EventTypeNotification, msg.Event) suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload) suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)

View File

@ -394,10 +394,7 @@ func (s *surface) notify(
if err != nil { if err != nil {
return gtserror.Newf("error converting notification to api representation: %w", err) return gtserror.Newf("error converting notification to api representation: %w", err)
} }
s.stream.Notify(ctx, targetAccount, apiNotif)
if err := s.stream.Notify(apiNotif, targetAccount); err != nil {
return gtserror.Newf("error streaming notification to account: %w", err)
}
return nil return nil
} }

View File

@ -348,11 +348,7 @@ func (s *surface) timelineStatus(
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
return true, err return true, err
} }
s.stream.Update(ctx, account, apiStatus, streamType)
if err := s.stream.Update(apiStatus, account, []string{streamType}); err != nil {
err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
return true, err
}
return true, nil return true, nil
} }
@ -363,12 +359,11 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string
if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil { if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil {
return err return err
} }
if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil { if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil {
return err return err
} }
s.stream.Delete(ctx, statusID)
return s.stream.Delete(statusID) return nil
} }
// invalidateStatusFromTimelines does cache invalidation on the given status by // invalidateStatusFromTimelines does cache invalidation on the given status by
@ -555,11 +550,6 @@ func (s *surface) timelineStreamStatusUpdate(
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
return err return err
} }
s.stream.StatusUpdate(ctx, account, apiStatus, streamType)
if err := s.stream.StatusUpdate(apiStatus, account, []string{streamType}); err != nil {
err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
return err
}
return nil return nil
} }

View File

@ -17,36 +17,65 @@
package stream package stream
import "sync" import (
"context"
const ( "maps"
// EventTypeNotification -- a user should be shown a notification "slices"
EventTypeNotification string = "notification" "sync"
// EventTypeUpdate -- a user should be shown an update in their timeline "sync/atomic"
EventTypeUpdate string = "update"
// EventTypeDelete -- something should be deleted from a user
EventTypeDelete string = "delete"
// EventTypeStatusUpdate -- something in the user's timeline has been edited
// (yes this is a confusing name, blame Mastodon)
EventTypeStatusUpdate string = "status.update"
) )
const ( const (
// TimelineLocal -- public statuses from the LOCAL timeline. // EventTypeNotification -- a user
TimelineLocal string = "public:local" // should be shown a notification.
// TimelinePublic -- public statuses, including federated ones. EventTypeNotification = "notification"
TimelinePublic string = "public"
// TimelineHome -- statuses for a user's Home timeline. // EventTypeUpdate -- a user should
TimelineHome string = "user" // be shown an update in their timeline.
// TimelineNotifications -- notification events. EventTypeUpdate = "update"
TimelineNotifications string = "user:notification"
// TimelineDirect -- statuses sent to a user directly. // EventTypeDelete -- something
TimelineDirect string = "direct" // should be deleted from a user.
// TimelineList -- statuses for a user's list timeline. EventTypeDelete = "delete"
TimelineList string = "list"
// EventTypeStatusUpdate -- something in the
// user's timeline has been edited (yes this
// is a confusing name, blame Mastodon ...).
EventTypeStatusUpdate = "status.update"
) )
// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes. const (
// TimelineLocal:
// All public posts originating from this
// server. Analogous to the local timeline.
TimelineLocal = "public:local"
// TimelinePublic:
// All public posts known to the server.
// Analogous to the federated timeline.
TimelinePublic = "public"
// TimelineHome:
// Events related to the current user, such
// as home feed updates and notifications.
TimelineHome = "user"
// TimelineNotifications:
// Notifications for the current user.
TimelineNotifications = "user:notification"
// TimelineDirect:
// Updates to direct conversations.
TimelineDirect = "direct"
// TimelineList:
// Updates to a specific list.
TimelineList = "list"
)
// AllStatusTimelines contains all Timelines
// that a status could conceivably be delivered
// to, useful for sending out status deletes.
var AllStatusTimelines = []string{ var AllStatusTimelines = []string{
TimelineLocal, TimelineLocal,
TimelinePublic, TimelinePublic,
@ -55,38 +84,298 @@ var AllStatusTimelines = []string{
TimelineList, TimelineList,
} }
// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. type Streams struct {
// TODO: put a limit on this streams map[string][]*Stream
type StreamsForAccount struct { mutex sync.Mutex
// The currently held streams for this account
Streams []*Stream
// Mutex to lock/unlock when modifying the slice of streams.
sync.Mutex
} }
// Stream represents one open stream for a client. // Open will open open a new Stream for given account ID and stream types, the given context will be passed to Stream.
func (s *Streams) Open(accountID string, streamTypes ...string) *Stream {
if len(streamTypes) == 0 {
panic("no stream types given")
}
// Prep new Stream.
str := new(Stream)
str.done = make(chan struct{})
str.msgCh = make(chan Message, 50) // TODO: make configurable
for _, streamType := range streamTypes {
str.Subscribe(streamType)
}
// TODO: add configurable
// max streams per account.
// Acquire lock.
s.mutex.Lock()
if s.streams == nil {
// Main stream-map needs allocating.
s.streams = make(map[string][]*Stream)
}
// Add new stream for account.
strs := s.streams[accountID]
strs = append(strs, str)
s.streams[accountID] = strs
// Register close callback
// to remove stream from our
// internal map for this account.
str.close = func() {
s.mutex.Lock()
strs := s.streams[accountID]
strs = slices.DeleteFunc(strs, func(s *Stream) bool {
return s == str // remove 'str' ptr
})
s.streams[accountID] = strs
s.mutex.Unlock()
}
// Done with lock.
s.mutex.Unlock()
return str
}
// Post will post the given message to all streams of given account ID matching type.
func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool {
var deferred []func() bool
// Acquire lock.
s.mutex.Lock()
// Iterate all streams stored for account.
for _, str := range s.streams[accountID] {
// Check whether stream supports any of our message targets.
if stype := str.getStreamType(msg.Stream...); stype != "" {
// Rescope var
// to prevent
// ptr reuse.
stream := str
// Use a message copy to *only*
// include the supported stream.
msgCopy := Message{
Stream: []string{stype},
Event: msg.Event,
Payload: msg.Payload,
}
// Send message to supported stream
// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
// This prevents deadlocks between each
// msg channel and main Streams{} mutex.
deferred = append(deferred, func() bool {
return stream.send(ctx, msgCopy)
})
}
}
// Done with lock.
s.mutex.Unlock()
var ok bool
// Execute deferred outside lock.
for _, deferfn := range deferred {
v := deferfn()
ok = ok && v
}
return ok
}
// PostAll will post the given message to all streams with matching types.
func (s *Streams) PostAll(ctx context.Context, msg Message) bool {
var deferred []func() bool
// Acquire lock.
s.mutex.Lock()
// Iterate ALL stored streams.
for _, strs := range s.streams {
for _, str := range strs {
// Check whether stream supports any of our message targets.
if stype := str.getStreamType(msg.Stream...); stype != "" {
// Rescope var
// to prevent
// ptr reuse.
stream := str
// Use a message copy to *only*
// include the supported stream.
msgCopy := Message{
Stream: []string{stype},
Event: msg.Event,
Payload: msg.Payload,
}
// Send message to supported stream
// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
// This prevents deadlocks between each
// msg channel and main Streams{} mutex.
deferred = append(deferred, func() bool {
return stream.send(ctx, msgCopy)
})
}
}
}
// Done with lock.
s.mutex.Unlock()
var ok bool
// Execute deferred outside lock.
for _, deferfn := range deferred {
v := deferfn()
ok = ok && v
}
return ok
}
// Stream represents one
// open stream for a client.
type Stream struct { type Stream struct {
// ID of this stream, generated during creation.
ID string // atomically updated ptr to a read-only copy
// A set of types subscribed to by this stream: user/public/etc. // of supported stream types in a hashmap. this
// It's a map to ensure no duplicates; the value is ignored. // gets updated via CAS operations in .cas().
StreamTypes map[string]any types atomic.Pointer[map[string]struct{}]
// Channel of messages for the client to read from
Messages chan *Message // protects stream close.
// Channel to close when the client drops away done chan struct{}
Hangup chan interface{}
// Only put messages in the stream when Connected // inbound msg ch.
Connected bool msgCh chan Message
// Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc.
sync.Mutex // close hook to remove
// stream from Streams{}.
close func()
} }
// Message represents one streamed message. // Subscribe will add given type to given types this stream supports.
func (s *Stream) Subscribe(streamType string) {
s.cas(func(m map[string]struct{}) bool {
if _, ok := m[streamType]; ok {
return false
}
m[streamType] = struct{}{}
return true
})
}
// Unsubscribe will remove given type (if found) from types this stream supports.
func (s *Stream) Unsubscribe(streamType string) {
s.cas(func(m map[string]struct{}) bool {
if _, ok := m[streamType]; !ok {
return false
}
delete(m, streamType)
return true
})
}
// getStreamType returns the first stream type in given list that stream supports.
func (s *Stream) getStreamType(streamTypes ...string) string {
if ptr := s.types.Load(); ptr != nil {
for _, streamType := range streamTypes {
if _, ok := (*ptr)[streamType]; ok {
return streamType
}
}
}
return ""
}
// send will block on posting a new Message{}, returning early with
// a false value if provided context is canceled, or stream closed.
func (s *Stream) send(ctx context.Context, msg Message) bool {
select {
case <-s.done:
return false
case <-ctx.Done():
return false
case s.msgCh <- msg:
return true
}
}
// Recv will block on receiving Message{}, returning early with a
// false value if provided context is canceled, or stream closed.
func (s *Stream) Recv(ctx context.Context) (Message, bool) {
select {
case <-s.done:
return Message{}, false
case <-ctx.Done():
return Message{}, false
case msg := <-s.msgCh:
return msg, true
}
}
// Close will close the underlying context, finally
// removing it from the parent Streams per-account-map.
func (s *Stream) Close() {
select {
case <-s.done:
default:
close(s.done)
s.close()
}
}
// cas will perform a Compare And Swap operation on s.types using modifier func.
func (s *Stream) cas(fn func(map[string]struct{}) bool) {
if fn == nil {
panic("nil function")
}
for {
var m map[string]struct{}
// Get current value.
ptr := s.types.Load()
if ptr == nil {
// Allocate new types map.
m = make(map[string]struct{})
} else {
// Clone r-only map.
m = maps.Clone(*ptr)
}
// Apply
// changes.
if !fn(m) {
return
}
// Attempt to Compare And Swap ptr.
if s.types.CompareAndSwap(ptr, &m) {
return
}
}
}
// Message represents
// one streamed message.
type Message struct { type Message struct {
// All the stream types this message should be delivered to.
// All the stream types this
// message should be delivered to.
Stream []string `json:"stream"` Stream []string `json:"stream"`
// The event type of the message (update/delete/notification etc)
// The event type of the message
// (update/delete/notification etc)
Event string `json:"event"` Event string `json:"event"`
// The actual payload of the message. In case of an update or notification, this will be a JSON string.
// The actual payload of the message. In case of an
// update or notification, this will be a JSON string.
Payload string `json:"payload"` Payload string `json:"payload"`
} }