[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code
* use new context for the stream context
* ensure stream gets closed on return of writeTo / readFrom WSConn()
* ensure stream write timeout gets cancelled
* remove embedded context type from Stream{}, reformat log messages for consistency
* use c.Request.Context() for context passed into Stream().Open()
* only return 1 boolean, fix tests to expect multiple stream types in messages
* changes to ping logic
* further improved ping logic
* don't export unused function types, update message sending to only include relevant stream type
* ensure stream gets closed 🤦
* update to error log on failed json marshal (instead of panic)
* inverse websocket read error checking to _ignore_ expected close errors
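In rough outline, the rewrite swaps the old per-account sync.Map + Hangup-channel bookkeeping for a Streams{} registry with short-lived internal locking, and gives each Stream{} a done channel plus context-aware Recv/Close. A hedged usage sketch of the new API as it appears in the diff below (the account ID, timeout, and payload are illustrative placeholders, not values from the commit):

package main

import (
	"context"
	"log"
	"time"

	"github.com/superseriousbusiness/gotosocial/internal/stream"
)

func main() {
	var streams stream.Streams

	// Open a stream for one account, subscribed to the home timeline.
	str := streams.Open("01EXAMPLEACCOUNTID", stream.TimelineHome)
	defer str.Close() // unregisters the stream from the Streams map

	// Producer side: Post only delivers to streams subscribed to a
	// matching type, and sends happen outside the registry mutex.
	go streams.Post(context.Background(), "01EXAMPLEACCOUNTID", stream.Message{
		Stream:  []string{stream.TimelineHome},
		Event:   stream.EventTypeUpdate,
		Payload: `{"example":"status"}`,
	})

	// Consumer side: every Recv is bounded by a context; the timeout
	// here stands in for the websocket handler's keep-alive interval.
	for {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		msg, ok := str.Recv(ctx)
		cancel()
		if !ok {
			log.Println("stream closed or receive timed out")
			return
		}
		log.Printf("received: %+v", msg)
	}
}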
parent 8cafa6b74b
commit 291e180990
@@ -22,10 +22,10 @@ import (
 	"slices"
 	"time"
 
-	"codeberg.org/gruf/go-kv"
 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
 	"github.com/superseriousbusiness/gotosocial/internal/gtserror"
 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+	"github.com/superseriousbusiness/gotosocial/internal/id"
 	"github.com/superseriousbusiness/gotosocial/internal/log"
 	"github.com/superseriousbusiness/gotosocial/internal/oauth"
 	streampkg "github.com/superseriousbusiness/gotosocial/internal/stream"
@@ -202,7 +202,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
 	// functions pass messages into a channel, which we can
 	// then read from and put into a websockets connection.
 	stream, errWithCode := m.processor.Stream().Open(
-		c.Request.Context(),
+		c.Request.Context(), // this ctx is only used for logging
 		account,
 		streamType,
 	)
@@ -213,10 +213,8 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
 
 	l := log.
 		WithContext(c.Request.Context()).
-		WithFields(kv.Fields{
-			{"username", account.Username},
-			{"streamID", stream.ID},
-		}...)
+		WithField("streamID", id.NewULID()).
+		WithField("username", account.Username)
 
 	// Upgrade the incoming HTTP request. This hijacks the
 	// underlying connection and reuses it for the websocket
@@ -227,18 +225,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
 	wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
 	if err != nil {
 		l.Errorf("error upgrading websocket connection: %v", err)
-		close(stream.Hangup)
+		stream.Close()
 		return
 	}
 
-	l.Info("opened websocket connection")
-
 	// We perform the main websocket rw loops in a separate
 	// goroutine in order to let the upgrade handler return.
 	// This prevents the upgrade handler from holding open any
 	// throttle / rate-limit request tokens which could become
 	// problematic on instances with multiple users.
-	go m.handleWSConn(account.Username, wsConn, stream)
+	go m.handleWSConn(&l, wsConn, stream)
 }
 
 // handleWSConn handles a two-way websocket streaming connection.
@@ -246,48 +242,39 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
 // into the connection. If any errors are encountered while reading
 // or writing (including expected errors like clients leaving), the
 // connection will be closed.
-func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) {
-	// Create new context for the lifetime of this connection.
-	ctx, cancel := context.WithCancel(context.Background())
+func (m *Module) handleWSConn(l *log.Entry, wsConn *websocket.Conn, stream *streampkg.Stream) {
+	l.Info("opened websocket connection")
 
-	l := log.
-		WithContext(ctx).
-		WithFields(kv.Fields{
-			{"username", username},
-			{"streamID", stream.ID},
-		}...)
-
-	// Create ticker to send keepalive pings
-	pinger := time.NewTicker(m.dTicker)
+	// Create new async context with cancel.
+	ctx, cncl := context.WithCancel(context.Background())
 
-	// Read messages coming from the Websocket client connection into the server.
 	go func() {
-		defer cancel()
-		m.readFromWSConn(ctx, username, wsConn, stream)
+		defer cncl()
+
+		// Read messages from websocket to server.
+		m.readFromWSConn(ctx, wsConn, stream, l)
 	}()
 
-	// Write messages coming from the processor into the Websocket client connection.
 	go func() {
-		defer cancel()
-		m.writeToWSConn(ctx, username, wsConn, stream, pinger)
+		defer cncl()
+
+		// Write messages from processor in websocket conn.
+		m.writeToWSConn(ctx, wsConn, stream, m.dTicker, l)
 	}()
 
-	// Wait for either the read or write functions to close, to indicate
-	// that the client has left, or something else has gone wrong.
+	// Wait for ctx
+	// to be closed.
 	<-ctx.Done()
 
+	// Close stream
+	// straightaway.
+	stream.Close()
+
 	// Tidy up underlying websocket connection.
 	if err := wsConn.Close(); err != nil {
 		l.Errorf("error closing websocket connection: %v", err)
 	}
 
-	// Close processor channel so the processor knows
-	// not to send any more messages to this stream.
-	close(stream.Hangup)
-
-	// Stop ping ticker (tiny resource saving).
-	pinger.Stop()
-
 	l.Info("closed websocket connection")
 }
 
@@ -299,89 +286,64 @@ func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *s
 // if the given context is canceled.
 func (m *Module) readFromWSConn(
 	ctx context.Context,
-	username string,
 	wsConn *websocket.Conn,
 	stream *streampkg.Stream,
+	l *log.Entry,
 ) {
-	l := log.
-		WithContext(ctx).
-		WithFields(kv.Fields{
-			{"username", username},
-			{"streamID", stream.ID},
-		}...)
-
-readLoop:
 	for {
-		select {
-		case <-ctx.Done():
-			// Connection closed.
-			break readLoop
+		var msg struct {
+			Type   string `json:"type"`
+			Stream string `json:"stream"`
+			List   string `json:"list,omitempty"`
+		}
 
-		default:
-			// Read JSON objects from the client and act on them.
-			var msg map[string]string
-			if err := wsConn.ReadJSON(&msg); err != nil {
-				// Only log an error if something weird happened.
-				// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
-				if websocket.IsUnexpectedCloseError(err, []int{
-					websocket.CloseNormalClosure,
-					websocket.CloseGoingAway,
-					websocket.CloseNoStatusReceived,
-				}...) {
-					l.Errorf("error reading from websocket: %v", err)
-				}
+		// Read JSON objects from the client and act on them.
+		if err := wsConn.ReadJSON(&msg); err != nil {
+			// Only log an error if something weird happened.
+			// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
+			if !websocket.IsCloseError(err, []int{
+				websocket.CloseNormalClosure,
+				websocket.CloseGoingAway,
+				websocket.CloseNoStatusReceived,
+			}...) {
+				l.Errorf("error during websocket read: %v", err)
+			}
 
-				// The connection is gone; no
-				// further streaming possible.
-				break readLoop
-			}
+			// The connection is gone; no
+			// further streaming possible.
+			break
+		}
 
-			// Messages *from* the WS connection are infrequent
-			// and usually interesting, so log this at info.
-			l.Infof("received message from websocket: %v", msg)
+		// Messages *from* the WS connection are infrequent
+		// and usually interesting, so log this at info.
+		l.Infof("received websocket message: %+v", msg)
 
-			// If the message contains 'stream' and 'type' fields, we can
-			// update the set of timelines that are subscribed for events.
-			updateType, ok := msg["type"]
-			if !ok {
-				l.Warn("'type' field not provided")
-				continue
-			}
-
-			updateStream, ok := msg["stream"]
-			if !ok {
-				l.Warn("'stream' field not provided")
-				continue
-			}
-
-			// Ignore if the updateStreamType is unknown (or missing),
-			// so a bad client can't cause extra memory allocations
-			if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
-				l.Warnf("unknown 'stream' field: %v", msg)
-				continue
-			}
+		// Ignore if the updateStreamType is unknown (or missing),
+		// so a bad client can't cause extra memory allocations
+		if !slices.Contains(streampkg.AllStatusTimelines, msg.Stream) {
+			l.Warnf("unknown 'stream' field: %v", msg)
+			continue
+		}
 
-			updateList, ok := msg["list"]
-			if ok {
-				updateStream += ":" + updateList
-			}
+		if msg.List != "" {
+			// If a list is given, add this to
+			// the stream name as this is how we
+			// we track stream types internally.
+			msg.Stream += ":" + msg.List
+		}
 
-			switch updateType {
-			case "subscribe":
-				stream.Lock()
-				stream.StreamTypes[updateStream] = true
-				stream.Unlock()
-			case "unsubscribe":
-				stream.Lock()
-				delete(stream.StreamTypes, updateStream)
-				stream.Unlock()
-			default:
-				l.Warnf("invalid 'type' field: %v", msg)
-			}
+		switch msg.Type {
+		case "subscribe":
+			stream.Subscribe(msg.Stream)
+		case "unsubscribe":
+			stream.Unsubscribe(msg.Stream)
+		default:
+			l.Warnf("invalid 'type' field: %v", msg)
 		}
 	}
 
-	l.Debug("finished reading from websocket connection")
+	l.Debug("finished websocket read")
 }
 
 // writeToWSConn receives messages coming from the processor via the
@@ -393,46 +355,47 @@ readLoop:
 // if the given context is canceled.
 func (m *Module) writeToWSConn(
 	ctx context.Context,
-	username string,
 	wsConn *websocket.Conn,
 	stream *streampkg.Stream,
-	pinger *time.Ticker,
+	ping time.Duration,
+	l *log.Entry,
 ) {
-	l := log.
-		WithContext(ctx).
-		WithFields(kv.Fields{
-			{"username", username},
-			{"streamID", stream.ID},
-		}...)
-
-writeLoop:
 	for {
-		select {
-		case <-ctx.Done():
-			// Connection closed.
-			break writeLoop
+		// Wrap context with timeout to send a ping.
+		pingctx, cncl := context.WithTimeout(ctx, ping)
 
-		case msg := <-stream.Messages:
-			// Received a new message from the processor.
-			l.Tracef("writing message to websocket: %+v", msg)
-			if err := wsConn.WriteJSON(msg); err != nil {
-				l.Debugf("error writing json to websocket: %v", err)
-				break writeLoop
-			}
+		// Block on receipt of msg.
+		msg, ok := stream.Recv(pingctx)
 
-			// Reset pinger on successful send, since
-			// we know the connection is still there.
-			pinger.Reset(m.dTicker)
+		// Check if cancel because ping.
+		pinged := (pingctx.Err() != nil)
+		cncl()
 
-		case <-pinger.C:
-			// Time to send a keep-alive "ping".
-			l.Trace("writing ping control message to websocket")
+		switch {
+		case !ok && pinged:
+			// The ping context timed out!
+			l.Trace("writing websocket ping")
+
+			// Wrapped context time-out, send a keep-alive "ping".
 			if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
-				l.Debugf("error writing ping to websocket: %v", err)
-				break writeLoop
+				l.Debugf("error writing websocket ping: %v", err)
+				break
 			}
+
+		case !ok:
+			// Stream was
+			// closed.
+			return
 		}
+
+		l.Trace("writing websocket message: %+v", msg)
+
+		// Received a new message from the processor.
+		if err := wsConn.WriteJSON(msg); err != nil {
+			l.Debugf("error writing websocket message: %v", err)
+			break
+		}
 	}
 
-	l.Debug("finished writing to websocket connection")
+	l.Debug("finished websocket write")
 }
 
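The ping change above drops the time.Ticker entirely: each pass of the write loop wraps the blocking receive in a context carrying the ping duration, and a receive that fails because that wrapped context (rather than the parent) expired is treated as "time to send a keep-alive ping". A small standalone sketch of the same pattern, with an illustrative channel and durations that are not the handler's real values:

package main

import (
	"context"
	"fmt"
	"time"
)

// recv blocks for one message, giving up when ctx is done.
func recv(ctx context.Context, msgCh <-chan string) (string, bool) {
	select {
	case <-ctx.Done():
		return "", false
	case msg := <-msgCh:
		return msg, true
	}
}

func main() {
	parent := context.Background()
	msgCh := make(chan string, 1)
	ping := 500 * time.Millisecond

	go func() {
		time.Sleep(1200 * time.Millisecond)
		msgCh <- "hello"
	}()

	for i := 0; i < 4; i++ {
		// Wrap context with a timeout standing in for the ping interval.
		pingctx, cncl := context.WithTimeout(parent, ping)
		msg, ok := recv(pingctx, msgCh)

		// Check whether the receive gave up because the wrapped
		// (ping) context expired, as opposed to a normal receive.
		pinged := pingctx.Err() != nil
		cncl()

		switch {
		case !ok && pinged:
			fmt.Println("no message within interval: send keep-alive ping")
		case !ok:
			fmt.Println("stream closed / parent context done: stop")
			return
		default:
			fmt.Println("got message:", msg)
		}
	}
}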
@@ -18,38 +18,16 @@
 package stream
 
 import (
-	"fmt"
-	"strings"
+	"context"
 
 	"github.com/superseriousbusiness/gotosocial/internal/stream"
 )
 
 // Delete streams the delete of the given statusID to *ALL* open streams.
-func (p *Processor) Delete(statusID string) error {
-	errs := []string{}
-
-	// get all account IDs with open streams
-	accountIDs := []string{}
-	p.streamMap.Range(func(k interface{}, _ interface{}) bool {
-		key, ok := k.(string)
-		if !ok {
-			panic("streamMap key was not a string (account id)")
-		}
-
-		accountIDs = append(accountIDs, key)
-		return true
+func (p *Processor) Delete(ctx context.Context, statusID string) {
+	p.streams.PostAll(ctx, stream.Message{
+		Payload: statusID,
+		Event:   stream.EventTypeDelete,
+		Stream:  stream.AllStatusTimelines,
 	})
-
-	// stream the delete to every account
-	for _, accountID := range accountIDs {
-		if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil {
-			errs = append(errs, err.Error())
-		}
-	}
-
-	if len(errs) != 0 {
-		return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";"))
-	}
-
-	return nil
 }
 
@@ -18,20 +18,29 @@
 package stream
 
 import (
+	"context"
 	"encoding/json"
-	"fmt"
 
+	"codeberg.org/gruf/go-byteutil"
 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+	"github.com/superseriousbusiness/gotosocial/internal/log"
 	"github.com/superseriousbusiness/gotosocial/internal/stream"
 )
 
 // Notify streams the given notification to any open, appropriate streams belonging to the given account.
-func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error {
-	bytes, err := json.Marshal(n)
+func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) {
+	b, err := json.Marshal(notif)
 	if err != nil {
-		return fmt.Errorf("error marshalling notification to json: %s", err)
+		log.Errorf(ctx, "error marshaling json: %v", err)
+		return
 	}
-
-	return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID)
+	p.streams.Post(ctx, account.ID, stream.Message{
+		Payload: byteutil.B2S(b),
+		Event:   stream.EventTypeNotification,
+		Stream: []string{
+			stream.TimelineNotifications,
+			stream.TimelineHome,
+		},
+	})
 }
 
@@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() {
 		Account: followAccountAPIModel,
 	}
 
-	err = suite.streamProcessor.Notify(notification, account)
-	suite.NoError(err)
+	suite.streamProcessor.Notify(context.Background(), account, notification)
+
+	msg, ok := openStream.Recv(context.Background())
+	suite.True(ok)
 
-	msg := <-openStream.Messages
 	dst := new(bytes.Buffer)
 	err = json.Indent(dst, []byte(msg.Payload), "", "  ")
 	suite.NoError(err)
@@ -19,13 +19,10 @@ package stream
 
 import (
 	"context"
-	"errors"
-	"fmt"
 
 	"codeberg.org/gruf/go-kv"
 	"github.com/superseriousbusiness/gotosocial/internal/gtserror"
 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-	"github.com/superseriousbusiness/gotosocial/internal/id"
 	"github.com/superseriousbusiness/gotosocial/internal/log"
 	"github.com/superseriousbusiness/gotosocial/internal/stream"
 )
@@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT
 		{"streamType", streamType},
 	}...)
 	l.Debug("received open stream request")
-
-	var (
-		streamID string
-		err      error
-	)
-
-	// Each stream needs a unique ID so we know to close it.
-	streamID, err = id.NewRandomULID()
-	if err != nil {
-		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err))
-	}
-
-	// Each stream can be subscibed to multiple types.
-	// Record them in a set, and include the initial one
-	// if it was given to us.
-	streamTypes := map[string]any{}
-	if streamType != "" {
-		streamTypes[streamType] = true
-	}
-
-	newStream := &stream.Stream{
-		ID:          streamID,
-		StreamTypes: streamTypes,
-		Messages:    make(chan *stream.Message, 100),
-		Hangup:      make(chan interface{}, 1),
-		Connected:   true,
-	}
-	go p.waitToCloseStream(account, newStream)
-
-	v, ok := p.streamMap.Load(account.ID)
-	if ok {
-		// There is an entry in the streamMap
-		// for this account. Parse it out.
-		streamsForAccount, ok := v.(*stream.StreamsForAccount)
-		if !ok {
-			return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
-		}
-
-		// Append new stream to existing entry.
-		streamsForAccount.Lock()
-		streamsForAccount.Streams = append(streamsForAccount.Streams, newStream)
-		streamsForAccount.Unlock()
-	} else {
-		// There is no entry in the streamMap for
-		// this account yet. Create one and store it.
-		p.streamMap.Store(account.ID, &stream.StreamsForAccount{
-			Streams: []*stream.Stream{
-				newStream,
-			},
-		})
-	}
-
-	return newStream, nil
-}
-
-// waitToCloseStream waits until the hangup channel is closed for the given stream.
-// It then iterates through the map of streams stored by the processor, removes the stream from it,
-// and then closes the messages channel of the stream to indicate that the channel should no longer be read from.
-func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) {
-	<-thisStream.Hangup // wait for a hangup message
-
-	// lock the stream to prevent more messages being put in it while we work
-	thisStream.Lock()
-	defer thisStream.Unlock()
-
-	// indicate the stream is no longer connected
-	thisStream.Connected = false
-
-	// load and parse the entry for this account from the stream map
-	v, ok := p.streamMap.Load(account.ID)
-	if !ok || v == nil {
-		return
-	}
-	streamsForAccount, ok := v.(*stream.StreamsForAccount)
-	if !ok {
-		return
-	}
-
-	// lock the streams for account while we remove this stream from its slice
-	streamsForAccount.Lock()
-	defer streamsForAccount.Unlock()
-
-	// put everything into modified streams *except* the stream we're removing
-	modifiedStreams := []*stream.Stream{}
-	for _, s := range streamsForAccount.Streams {
-		if s.ID != thisStream.ID {
-			modifiedStreams = append(modifiedStreams, s)
-		}
-	}
-	streamsForAccount.Streams = modifiedStreams
-
-	// finally close the messages channel so no more messages can be read from it
-	close(thisStream.Messages)
+	return p.streams.Open(account.ID, streamType), nil
 }
 
@@ -18,21 +18,26 @@
 package stream
 
 import (
+	"context"
 	"encoding/json"
-	"fmt"
 
+	"codeberg.org/gruf/go-byteutil"
 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+	"github.com/superseriousbusiness/gotosocial/internal/log"
 	"github.com/superseriousbusiness/gotosocial/internal/stream"
 )
 
-// StatusUpdate streams the given edited status to any open, appropriate
-// streams belonging to the given account.
-func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
-	bytes, err := json.Marshal(s)
+// StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account.
+func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
+	b, err := json.Marshal(status)
 	if err != nil {
-		return fmt.Errorf("error marshalling status to json: %s", err)
+		log.Errorf(ctx, "error marshaling json: %v", err)
+		return
 	}
-
-	return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID)
+	p.streams.Post(ctx, account.ID, stream.Message{
+		Payload: byteutil.B2S(b),
+		Event:   stream.EventTypeStatusUpdate,
+		Stream:  []string{streamType},
+	})
 }
 
@@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() {
 	apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account)
 	suite.NoError(err)
 
-	err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome})
-	suite.NoError(err)
+	suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome)
+
+	msg, ok := openStream.Recv(context.Background())
+	suite.True(ok)
 
-	msg := <-openStream.Messages
 	dst := new(bytes.Buffer)
 	err = json.Indent(dst, []byte(msg.Payload), "", "  ")
 	suite.NoError(err)
@@ -18,8 +18,6 @@
 package stream
 
 import (
-	"sync"
-
 	"github.com/superseriousbusiness/gotosocial/internal/oauth"
 	"github.com/superseriousbusiness/gotosocial/internal/state"
 	"github.com/superseriousbusiness/gotosocial/internal/stream"
@@ -28,53 +26,13 @@ import (
 type Processor struct {
 	state       *state.State
 	oauthServer oauth.Server
-	streamMap   *sync.Map
+	streams     stream.Streams
 }
 
 func New(state *state.State, oauthServer oauth.Server) Processor {
 	return Processor{
 		state:       state,
 		oauthServer: oauthServer,
-		streamMap:   &sync.Map{},
+		streams:     stream.Streams{},
 	}
 }
-
-// toAccount streams the given payload with the given event type to any streams currently open for the given account ID.
-func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error {
-	// Load all streams open for this account.
-	v, ok := p.streamMap.Load(accountID)
-	if !ok {
-		return nil // No entry = nothing to stream.
-	}
-	streamsForAccount := v.(*stream.StreamsForAccount)
-
-	streamsForAccount.Lock()
-	defer streamsForAccount.Unlock()
-
-	for _, s := range streamsForAccount.Streams {
-		s.Lock()
-		defer s.Unlock()
-
-		if !s.Connected {
-			continue
-		}
-
-	typeLoop:
-		for _, streamType := range streamTypes {
-			if _, found := s.StreamTypes[streamType]; found {
-				s.Messages <- &stream.Message{
-					Stream:  []string{streamType},
-					Event:   string(event),
-					Payload: payload,
-				}
-
-				// Break out to the outer loop,
-				// to avoid sending duplicates of
-				// the same event to the same stream.
-				break typeLoop
-			}
-		}
-	}
-
-	return nil
-}
 
@@ -18,20 +18,26 @@
 package stream
 
 import (
+	"context"
 	"encoding/json"
-	"fmt"
 
+	"codeberg.org/gruf/go-byteutil"
 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+	"github.com/superseriousbusiness/gotosocial/internal/log"
 	"github.com/superseriousbusiness/gotosocial/internal/stream"
 )
 
 // Update streams the given update to any open, appropriate streams belonging to the given account.
-func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
-	bytes, err := json.Marshal(s)
+func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
+	b, err := json.Marshal(status)
 	if err != nil {
-		return fmt.Errorf("error marshalling status to json: %s", err)
+		log.Errorf(ctx, "error marshaling json: %v", err)
+		return
 	}
-
-	return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID)
+	p.streams.Post(ctx, account.ID, stream.Message{
+		Payload: byteutil.B2S(b),
+		Event:   stream.EventTypeUpdate,
+		Stream:  []string{streamType},
+	})
 }
 
@@ -116,23 +116,20 @@ func (suite *FromClientAPITestSuite) checkStreamed(
 	expectPayload string,
 	expectEventType string,
 ) {
-	var msg *stream.Message
-streamLoop:
-	for {
-		select {
-		case msg = <-str.Messages:
-			break streamLoop // Got it.
-		case <-time.After(5 * time.Second):
-			break streamLoop // Didn't get it.
-		}
+	// Set a 5s timeout on context.
+	ctx := context.Background()
+	ctx, cncl := context.WithTimeout(ctx, time.Second*5)
+	defer cncl()
+
+	msg, ok := str.Recv(ctx)
+
+	if expectMessage && !ok {
+		suite.FailNow("expected a message but message was not received")
 	}
 
-	if expectMessage && msg == nil {
-		suite.FailNow("expected a message but message was nil")
-	}
-
-	if !expectMessage && msg != nil {
-		suite.FailNow("expected no message but message was not nil")
+	if !expectMessage && ok {
+		suite.FailNow("expected no message but message was received")
 	}
 
 	if expectPayload != "" && msg.Payload != expectPayload {
@@ -130,14 +130,9 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() {
 	suite.Equal(replyingStatus.ID, notif.StatusID)
 	suite.False(*notif.Read)
 
-	// the notification should be streamed
-	var msg *stream.Message
-	select {
-	case msg = <-wssStream.Messages:
-		// fine
-	case <-time.After(5 * time.Second):
-		suite.FailNow("no message from wssStream")
-	}
+	ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+	msg, ok := wssStream.Recv(ctx)
+	suite.True(ok)
 
 	suite.Equal(stream.EventTypeNotification, msg.Event)
 	suite.NotEmpty(msg.Payload)
@@ -203,14 +198,10 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {
 	suite.Equal(fave.StatusID, notif.StatusID)
 	suite.False(*notif.Read)
 
-	// 2. a notification should be streamed
-	var msg *stream.Message
-	select {
-	case msg = <-wssStream.Messages:
-		// fine
-	case <-time.After(5 * time.Second):
-		suite.FailNow("no message from wssStream")
-	}
+	ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+	msg, ok := wssStream.Recv(ctx)
+	suite.True(ok)
+
 	suite.Equal(stream.EventTypeNotification, msg.Event)
 	suite.NotEmpty(msg.Payload)
 	suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream)
@@ -277,7 +268,9 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(
 	suite.False(*notif.Read)
 
 	// 2. no notification should be streamed to the account that received the fave message, because they weren't the target
-	suite.Empty(wssStream.Messages)
+	ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+	_, ok := wssStream.Recv(ctx)
+	suite.False(ok)
 }
 
 func (suite *FromFediAPITestSuite) TestProcessAccountDelete() {
@@ -405,14 +398,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
 	})
 	suite.NoError(err)
 
-	// a notification should be streamed
-	var msg *stream.Message
-	select {
-	case msg = <-wssStream.Messages:
-		// fine
-	case <-time.After(5 * time.Second):
-		suite.FailNow("no message from wssStream")
-	}
+	ctx, _ = context.WithTimeout(ctx, time.Second*5)
+	msg, ok := wssStream.Recv(context.Background())
+	suite.True(ok)
+
 	suite.Equal(stream.EventTypeNotification, msg.Event)
 	suite.NotEmpty(msg.Payload)
 	suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
@@ -423,7 +412,7 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
 	suite.Equal(originAccount.ID, notif.Account.ID)
 
 	// no messages should have been sent out, since we didn't need to federate an accept
-	suite.Empty(suite.httpClient.SentMessages)
+	suite.Empty(&suite.httpClient.SentMessages)
 }
 
 func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
@@ -503,14 +492,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
 	suite.Equal(originAccount.URI, accept.To)
 	suite.Equal("Accept", accept.Type)
 
-	// a notification should be streamed
-	var msg *stream.Message
-	select {
-	case msg = <-wssStream.Messages:
-		// fine
-	case <-time.After(5 * time.Second):
-		suite.FailNow("no message from wssStream")
-	}
+	ctx, _ = context.WithTimeout(ctx, time.Second*5)
+	msg, ok := wssStream.Recv(context.Background())
+	suite.True(ok)
+
 	suite.Equal(stream.EventTypeNotification, msg.Event)
 	suite.NotEmpty(msg.Payload)
 	suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
@@ -394,10 +394,7 @@ func (s *surface) notify(
 	if err != nil {
 		return gtserror.Newf("error converting notification to api representation: %w", err)
 	}
+	s.stream.Notify(ctx, targetAccount, apiNotif)
 
-	if err := s.stream.Notify(apiNotif, targetAccount); err != nil {
-		return gtserror.Newf("error streaming notification to account: %w", err)
-	}
-
 	return nil
 }
@@ -348,11 +348,7 @@ func (s *surface) timelineStatus(
 		err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
 		return true, err
 	}
+	s.stream.Update(ctx, account, apiStatus, streamType)
 
-	if err := s.stream.Update(apiStatus, account, []string{streamType}); err != nil {
-		err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
-		return true, err
-	}
-
 	return true, nil
 }
@@ -363,12 +359,11 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string
 	if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil {
 		return err
 	}
 
 	if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil {
 		return err
 	}
+	s.stream.Delete(ctx, statusID)
 
-	return s.stream.Delete(statusID)
+	return nil
 }
 
 // invalidateStatusFromTimelines does cache invalidation on the given status by
@@ -555,11 +550,6 @@ func (s *surface) timelineStreamStatusUpdate(
 		err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
 		return err
 	}
+	s.stream.StatusUpdate(ctx, account, apiStatus, streamType)
 
-	if err := s.stream.StatusUpdate(apiStatus, account, []string{streamType}); err != nil {
-		err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
-		return err
-	}
-
 	return nil
 }
@@ -17,36 +17,65 @@
 
 package stream
 
-import "sync"
-
-const (
-	// EventTypeNotification -- a user should be shown a notification
-	EventTypeNotification string = "notification"
-	// EventTypeUpdate -- a user should be shown an update in their timeline
-	EventTypeUpdate string = "update"
-	// EventTypeDelete -- something should be deleted from a user
-	EventTypeDelete string = "delete"
-	// EventTypeStatusUpdate -- something in the user's timeline has been edited
-	// (yes this is a confusing name, blame Mastodon)
-	EventTypeStatusUpdate string = "status.update"
+import (
+	"context"
+	"maps"
+	"slices"
+	"sync"
+	"sync/atomic"
 )
 
 const (
-	// TimelineLocal -- public statuses from the LOCAL timeline.
-	TimelineLocal string = "public:local"
-	// TimelinePublic -- public statuses, including federated ones.
-	TimelinePublic string = "public"
-	// TimelineHome -- statuses for a user's Home timeline.
-	TimelineHome string = "user"
-	// TimelineNotifications -- notification events.
-	TimelineNotifications string = "user:notification"
-	// TimelineDirect -- statuses sent to a user directly.
-	TimelineDirect string = "direct"
-	// TimelineList -- statuses for a user's list timeline.
-	TimelineList string = "list"
+	// EventTypeNotification -- a user
+	// should be shown a notification.
+	EventTypeNotification = "notification"
+
+	// EventTypeUpdate -- a user should
+	// be shown an update in their timeline.
+	EventTypeUpdate = "update"
+
+	// EventTypeDelete -- something
+	// should be deleted from a user.
+	EventTypeDelete = "delete"
+
+	// EventTypeStatusUpdate -- something in the
+	// user's timeline has been edited (yes this
+	// is a confusing name, blame Mastodon ...).
+	EventTypeStatusUpdate = "status.update"
 )
 
-// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes.
+const (
+	// TimelineLocal:
+	// All public posts originating from this
+	// server. Analogous to the local timeline.
+	TimelineLocal = "public:local"
+
+	// TimelinePublic:
+	// All public posts known to the server.
+	// Analogous to the federated timeline.
+	TimelinePublic = "public"
+
+	// TimelineHome:
+	// Events related to the current user, such
+	// as home feed updates and notifications.
+	TimelineHome = "user"
+
+	// TimelineNotifications:
+	// Notifications for the current user.
+	TimelineNotifications = "user:notification"
+
+	// TimelineDirect:
+	// Updates to direct conversations.
+	TimelineDirect = "direct"
+
+	// TimelineList:
+	// Updates to a specific list.
+	TimelineList = "list"
+)
+
+// AllStatusTimelines contains all Timelines
+// that a status could conceivably be delivered
+// to, useful for sending out status deletes.
 var AllStatusTimelines = []string{
 	TimelineLocal,
 	TimelinePublic,
@@ -55,38 +84,298 @@ var AllStatusTimelines = []string{
 	TimelineList,
 }
 
-// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time.
-// TODO: put a limit on this
-type StreamsForAccount struct {
-	// The currently held streams for this account
-	Streams []*Stream
-	// Mutex to lock/unlock when modifying the slice of streams.
-	sync.Mutex
+type Streams struct {
+	streams map[string][]*Stream
+	mutex   sync.Mutex
 }
 
-// Stream represents one open stream for a client.
+// Open will open open a new Stream for given account ID and stream types, the given context will be passed to Stream.
+func (s *Streams) Open(accountID string, streamTypes ...string) *Stream {
+	if len(streamTypes) == 0 {
+		panic("no stream types given")
+	}
+
+	// Prep new Stream.
+	str := new(Stream)
+	str.done = make(chan struct{})
+	str.msgCh = make(chan Message, 50) // TODO: make configurable
+	for _, streamType := range streamTypes {
+		str.Subscribe(streamType)
+	}
+
+	// TODO: add configurable
+	// max streams per account.
+
+	// Acquire lock.
+	s.mutex.Lock()
+
+	if s.streams == nil {
+		// Main stream-map needs allocating.
+		s.streams = make(map[string][]*Stream)
+	}
+
+	// Add new stream for account.
+	strs := s.streams[accountID]
+	strs = append(strs, str)
+	s.streams[accountID] = strs
+
+	// Register close callback
+	// to remove stream from our
+	// internal map for this account.
+	str.close = func() {
+		s.mutex.Lock()
+		strs := s.streams[accountID]
+		strs = slices.DeleteFunc(strs, func(s *Stream) bool {
+			return s == str // remove 'str' ptr
+		})
+		s.streams[accountID] = strs
+		s.mutex.Unlock()
+	}
+
+	// Done with lock.
+	s.mutex.Unlock()
+
+	return str
+}
+
+// Post will post the given message to all streams of given account ID matching type.
+func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool {
+	var deferred []func() bool
+
+	// Acquire lock.
+	s.mutex.Lock()
+
+	// Iterate all streams stored for account.
+	for _, str := range s.streams[accountID] {
+
+		// Check whether stream supports any of our message targets.
+		if stype := str.getStreamType(msg.Stream...); stype != "" {
+
+			// Rescope var
+			// to prevent
+			// ptr reuse.
+			stream := str
+
+			// Use a message copy to *only*
+			// include the supported stream.
+			msgCopy := Message{
+				Stream:  []string{stype},
+				Event:   msg.Event,
+				Payload: msg.Payload,
+			}
+
+			// Send message to supported stream
+			// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
+			// This prevents deadlocks between each
+			// msg channel and main Streams{} mutex.
+			deferred = append(deferred, func() bool {
+				return stream.send(ctx, msgCopy)
+			})
+		}
+	}
+
+	// Done with lock.
+	s.mutex.Unlock()
+
+	var ok bool
+
+	// Execute deferred outside lock.
+	for _, deferfn := range deferred {
+		v := deferfn()
+		ok = ok && v
+	}
+
+	return ok
+}
+
+// PostAll will post the given message to all streams with matching types.
+func (s *Streams) PostAll(ctx context.Context, msg Message) bool {
+	var deferred []func() bool
+
+	// Acquire lock.
+	s.mutex.Lock()
+
+	// Iterate ALL stored streams.
+	for _, strs := range s.streams {
+		for _, str := range strs {
+
+			// Check whether stream supports any of our message targets.
+			if stype := str.getStreamType(msg.Stream...); stype != "" {
+
+				// Rescope var
+				// to prevent
+				// ptr reuse.
+				stream := str
+
+				// Use a message copy to *only*
+				// include the supported stream.
+				msgCopy := Message{
+					Stream:  []string{stype},
+					Event:   msg.Event,
+					Payload: msg.Payload,
+				}
+
+				// Send message to supported stream
+				// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
+				// This prevents deadlocks between each
+				// msg channel and main Streams{} mutex.
+				deferred = append(deferred, func() bool {
+					return stream.send(ctx, msgCopy)
+				})
+			}
+		}
+	}
+
+	// Done with lock.
+	s.mutex.Unlock()
+
+	var ok bool
+
+	// Execute deferred outside lock.
+	for _, deferfn := range deferred {
+		v := deferfn()
+		ok = ok && v
+	}
+
+	return ok
+}
+
+// Stream represents one
+// open stream for a client.
 type Stream struct {
-	// ID of this stream, generated during creation.
-	ID string
-	// A set of types subscribed to by this stream: user/public/etc.
-	// It's a map to ensure no duplicates; the value is ignored.
-	StreamTypes map[string]any
-	// Channel of messages for the client to read from
-	Messages chan *Message
-	// Channel to close when the client drops away
-	Hangup chan interface{}
-	// Only put messages in the stream when Connected
-	Connected bool
-	// Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc.
-	sync.Mutex
+	// atomically updated ptr to a read-only copy
+	// of supported stream types in a hashmap. this
+	// gets updated via CAS operations in .cas().
+	types atomic.Pointer[map[string]struct{}]
+
+	// protects stream close.
+	done chan struct{}
+
+	// inbound msg ch.
+	msgCh chan Message
+
+	// close hook to remove
+	// stream from Streams{}.
+	close func()
 }
 
-// Message represents one streamed message.
+// Subscribe will add given type to given types this stream supports.
+func (s *Stream) Subscribe(streamType string) {
+	s.cas(func(m map[string]struct{}) bool {
+		if _, ok := m[streamType]; ok {
+			return false
+		}
+		m[streamType] = struct{}{}
+		return true
+	})
+}
+
+// Unsubscribe will remove given type (if found) from types this stream supports.
+func (s *Stream) Unsubscribe(streamType string) {
+	s.cas(func(m map[string]struct{}) bool {
+		if _, ok := m[streamType]; !ok {
+			return false
+		}
+		delete(m, streamType)
+		return true
+	})
+}
+
+// getStreamType returns the first stream type in given list that stream supports.
+func (s *Stream) getStreamType(streamTypes ...string) string {
+	if ptr := s.types.Load(); ptr != nil {
+		for _, streamType := range streamTypes {
+			if _, ok := (*ptr)[streamType]; ok {
+				return streamType
+			}
+		}
+	}
+	return ""
+}
+
+// send will block on posting a new Message{}, returning early with
+// a false value if provided context is canceled, or stream closed.
+func (s *Stream) send(ctx context.Context, msg Message) bool {
+	select {
+	case <-s.done:
+		return false
+	case <-ctx.Done():
+		return false
+	case s.msgCh <- msg:
+		return true
+	}
+}
+
+// Recv will block on receiving Message{}, returning early with a
+// false value if provided context is canceled, or stream closed.
+func (s *Stream) Recv(ctx context.Context) (Message, bool) {
+	select {
+	case <-s.done:
+		return Message{}, false
+	case <-ctx.Done():
+		return Message{}, false
+	case msg := <-s.msgCh:
+		return msg, true
+	}
+}
+
+// Close will close the underlying context, finally
+// removing it from the parent Streams per-account-map.
+func (s *Stream) Close() {
+	select {
+	case <-s.done:
+	default:
+		close(s.done)
+		s.close()
+	}
+}
+
+// cas will perform a Compare And Swap operation on s.types using modifier func.
+func (s *Stream) cas(fn func(map[string]struct{}) bool) {
+	if fn == nil {
+		panic("nil function")
+	}
+	for {
+		var m map[string]struct{}
+
+		// Get current value.
+		ptr := s.types.Load()
+
+		if ptr == nil {
+			// Allocate new types map.
+			m = make(map[string]struct{})
+		} else {
+			// Clone r-only map.
+			m = maps.Clone(*ptr)
+		}
+
+		// Apply
+		// changes.
+		if !fn(m) {
+			return
+		}
+
+		// Attempt to Compare And Swap ptr.
+		if s.types.CompareAndSwap(ptr, &m) {
+			return
+		}
+	}
+}
+
+// Message represents
+// one streamed message.
 type Message struct {
-	// All the stream types this message should be delivered to.
+	// All the stream types this
+	// message should be delivered to.
 	Stream []string `json:"stream"`
-	// The event type of the message (update/delete/notification etc)
+
+	// The event type of the message
+	// (update/delete/notification etc)
 	Event string `json:"event"`
-	// The actual payload of the message. In case of an update or notification, this will be a JSON string.
+
+	// The actual payload of the message. In case of an
+	// update or notification, this will be a JSON string.
 	Payload string `json:"payload"`
 }
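The Subscribe / Unsubscribe / cas methods above keep the subscription set readable without any locking: readers Load() an immutable map behind an atomic pointer, while writers clone the map, mutate the clone, and publish it with CompareAndSwap, retrying on contention. A standalone sketch of that copy-on-write idea with hypothetical names (typeSet, update, has), not code from the commit:

package main

import (
	"fmt"
	"maps"
	"sync/atomic"
)

type typeSet struct {
	types atomic.Pointer[map[string]struct{}]
}

// update clones the current map, applies fn to the clone, and tries to
// publish it; fn returning false means "no change needed, stop here".
func (s *typeSet) update(fn func(map[string]struct{}) bool) {
	for {
		ptr := s.types.Load()

		var m map[string]struct{}
		if ptr == nil {
			m = make(map[string]struct{})
		} else {
			m = maps.Clone(*ptr) // never mutate the published copy
		}

		if !fn(m) {
			return
		}

		if s.types.CompareAndSwap(ptr, &m) {
			return
		}
		// Lost a race with another writer: reload and retry.
	}
}

// has is the lock-free read path.
func (s *typeSet) has(t string) bool {
	if ptr := s.types.Load(); ptr != nil {
		_, ok := (*ptr)[t]
		return ok
	}
	return false
}

func main() {
	var s typeSet
	s.update(func(m map[string]struct{}) bool {
		m["user"] = struct{}{}
		return true
	})
	fmt.Println(s.has("user"), s.has("public")) // true false
}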