mirror of
1
Fork 0

[bugfix] self-referencing collection pages for status replies (#2364)

This commit is contained in:
kim 2023-11-20 12:22:28 +00:00 committed by GitHub
parent efefdb1323
commit 16275853eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 611 additions and 427 deletions

View File

@ -49,6 +49,7 @@ func TestASCollection(t *testing.T) {
// Create new collection using builder function. // Create new collection using builder function.
c := ap.NewASCollection(ap.CollectionParams{ c := ap.NewASCollection(ap.CollectionParams{
ID: parseURI(idURI), ID: parseURI(idURI),
Query: url.Values{"limit": []string{"40"}},
Total: total, Total: total,
}) })
@ -56,7 +57,7 @@ func TestASCollection(t *testing.T) {
s := toJSON(c) s := toJSON(c)
// Ensure outputs are equal. // Ensure outputs are equal.
assert.Equal(t, s, expect) assert.Equal(t, expect, s)
} }
func TestASCollectionPage(t *testing.T) { func TestASCollectionPage(t *testing.T) {
@ -110,7 +111,7 @@ func TestASCollectionPage(t *testing.T) {
s := toJSON(p) s := toJSON(p)
// Ensure outputs are equal. // Ensure outputs are equal.
assert.Equal(t, s, expect) assert.Equal(t, expect, s)
} }
func TestASOrderedCollection(t *testing.T) { func TestASOrderedCollection(t *testing.T) {
@ -131,6 +132,7 @@ func TestASOrderedCollection(t *testing.T) {
// Create new collection using builder function. // Create new collection using builder function.
c := ap.NewASOrderedCollection(ap.CollectionParams{ c := ap.NewASOrderedCollection(ap.CollectionParams{
ID: parseURI(idURI), ID: parseURI(idURI),
Query: url.Values{"limit": []string{"40"}},
Total: total, Total: total,
}) })
@ -138,7 +140,7 @@ func TestASOrderedCollection(t *testing.T) {
s := toJSON(c) s := toJSON(c)
// Ensure outputs are equal. // Ensure outputs are equal.
assert.Equal(t, s, expect) assert.Equal(t, expect, s)
} }
func TestASOrderedCollectionPage(t *testing.T) { func TestASOrderedCollectionPage(t *testing.T) {
@ -192,7 +194,7 @@ func TestASOrderedCollectionPage(t *testing.T) {
s := toJSON(p) s := toJSON(p)
// Ensure outputs are equal. // Ensure outputs are equal.
assert.Equal(t, s, expect) assert.Equal(t, expect, s)
} }
func parseURI(s string) *url.URL { func parseURI(s string) *url.URL {

View File

@ -20,7 +20,6 @@ package ap
import ( import (
"fmt" "fmt"
"net/url" "net/url"
"strconv"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
@ -169,6 +168,10 @@ type CollectionParams struct {
// ID (i.e. NOT the page). // ID (i.e. NOT the page).
ID *url.URL ID *url.URL
// First page details.
First paging.Page
Query url.Values
// Total no. items. // Total no. items.
Total int Total int
} }
@ -224,7 +227,7 @@ type ItemsPropertyBuilder interface {
// NewASCollection builds and returns a new ActivityStreams Collection from given parameters. // NewASCollection builds and returns a new ActivityStreams Collection from given parameters.
func NewASCollection(params CollectionParams) vocab.ActivityStreamsCollection { func NewASCollection(params CollectionParams) vocab.ActivityStreamsCollection {
collection := streams.NewActivityStreamsCollection() collection := streams.NewActivityStreamsCollection()
buildCollection(collection, params, 40) buildCollection(collection, params)
return collection return collection
} }
@ -239,7 +242,7 @@ func NewASCollectionPage(params CollectionPageParams) vocab.ActivityStreamsColle
// NewASOrderedCollection builds and returns a new ActivityStreams OrderedCollection from given parameters. // NewASOrderedCollection builds and returns a new ActivityStreams OrderedCollection from given parameters.
func NewASOrderedCollection(params CollectionParams) vocab.ActivityStreamsOrderedCollection { func NewASOrderedCollection(params CollectionParams) vocab.ActivityStreamsOrderedCollection {
collection := streams.NewActivityStreamsOrderedCollection() collection := streams.NewActivityStreamsOrderedCollection()
buildCollection(collection, params, 40) buildCollection(collection, params)
return collection return collection
} }
@ -251,7 +254,7 @@ func NewASOrderedCollectionPage(params CollectionPageParams) vocab.ActivityStrea
return collectionPage return collectionPage
} }
func buildCollection[C CollectionBuilder](collection C, params CollectionParams, pageLimit int) { func buildCollection[C CollectionBuilder](collection C, params CollectionParams) {
// Add the collection ID property. // Add the collection ID property.
idProp := streams.NewJSONLDIdProperty() idProp := streams.NewJSONLDIdProperty()
idProp.SetIRI(params.ID) idProp.SetIRI(params.ID)
@ -262,15 +265,20 @@ func buildCollection[C CollectionBuilder](collection C, params CollectionParams,
totalItems.Set(params.Total) totalItems.Set(params.Total)
collection.SetActivityStreamsTotalItems(totalItems) collection.SetActivityStreamsTotalItems(totalItems)
// Clone the collection ID page // Append paging query params
// to add first page query data. // to those already in ID prop.
firstIRI := new(url.URL) pageQueryParams := appendQuery(
*firstIRI = *params.ID params.Query,
params.ID.Query(),
)
// Note that simply adding a limit signals to our // Build the first page link IRI.
// endpoint to use paging (which will start at beginning). firstIRI := params.First.ToLinkURL(
limit := "limit=" + strconv.Itoa(pageLimit) params.ID.Scheme,
firstIRI.RawQuery = appendQuery(firstIRI.RawQuery, limit) params.ID.Host,
params.ID.Path,
pageQueryParams,
)
// Add the collection first IRI property. // Add the collection first IRI property.
first := streams.NewActivityStreamsFirstProperty() first := streams.NewActivityStreamsFirstProperty()
@ -284,12 +292,19 @@ func buildCollectionPage[C CollectionPageBuilder, I ItemsPropertyBuilder](collec
partOfProp.SetIRI(params.ID) partOfProp.SetIRI(params.ID)
collectionPage.SetActivityStreamsPartOf(partOfProp) collectionPage.SetActivityStreamsPartOf(partOfProp)
// Append paging query params
// to those already in ID prop.
pageQueryParams := appendQuery(
params.Query,
params.ID.Query(),
)
// Build the current page link IRI. // Build the current page link IRI.
currentIRI := params.Current.ToLinkURL( currentIRI := params.Current.ToLinkURL(
params.ID.Scheme, params.ID.Scheme,
params.ID.Host, params.ID.Host,
params.ID.Path, params.ID.Path,
params.Query, pageQueryParams,
) )
// Add the collection ID property for // Add the collection ID property for
@ -303,7 +318,7 @@ func buildCollectionPage[C CollectionPageBuilder, I ItemsPropertyBuilder](collec
params.ID.Scheme, params.ID.Scheme,
params.ID.Host, params.ID.Host,
params.ID.Path, params.ID.Path,
params.Query, pageQueryParams,
) )
if nextIRI != nil { if nextIRI != nil {
@ -318,7 +333,7 @@ func buildCollectionPage[C CollectionPageBuilder, I ItemsPropertyBuilder](collec
params.ID.Scheme, params.ID.Scheme,
params.ID.Host, params.ID.Host,
params.ID.Path, params.ID.Path,
params.Query, pageQueryParams,
) )
if prevIRI != nil { if prevIRI != nil {
@ -349,11 +364,13 @@ func buildCollectionPage[C CollectionPageBuilder, I ItemsPropertyBuilder](collec
setItems(itemsProp) setItems(itemsProp)
} }
// appendQuery appends part to an existing raw // appendQuery appends query values in 'src' to 'dst', returning 'dst'.
// query with ampersand, else just returning part. func appendQuery(dst, src url.Values) url.Values {
func appendQuery(raw, part string) string { if dst == nil {
if raw != "" { return src
return raw + "&" + part
} }
return part for k, vs := range src {
dst[k] = append(dst[k], vs...)
}
return dst
} }

View File

@ -20,14 +20,13 @@ package users
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"net/http" "net/http"
"strconv"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/paging"
) )
// StatusRepliesGETHandler swagger:operation GET /users/{username}/statuses/{status}/replies s2sRepliesGet // StatusRepliesGETHandler swagger:operation GET /users/{username}/statuses/{status}/replies s2sRepliesGet
@ -120,36 +119,43 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) {
return return
} }
var page bool // Look for supplied 'only_other_accounts' query key.
if pageString := c.Query(PageKey); pageString != "" { onlyOtherAccounts, errWithCode := apiutil.ParseOnlyOtherAccounts(
i, err := strconv.ParseBool(pageString) c.Query(apiutil.OnlyOtherAccountsKey),
if err != nil { true, // default = enabled
err := fmt.Errorf("error parsing %s: %s", PageKey, err) )
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
page = i
}
onlyOtherAccounts := false // Look for given paging query parameters.
onlyOtherAccountsString := c.Query(OnlyOtherAccountsKey) page, errWithCode := paging.ParseIDPage(c,
if onlyOtherAccountsString != "" { 1, // min limit
i, err := strconv.ParseBool(onlyOtherAccountsString) 40, // max limit
if err != nil { 0, // default = disabled
err := fmt.Errorf("error parsing %s: %s", OnlyOtherAccountsKey, err) )
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
onlyOtherAccounts = i
// COMPATIBILITY FIX: 'page=true' enables paging.
if page == nil && c.Query("page") == "true" {
page = new(paging.Page)
page.Max = paging.MaxID("")
page.Min = paging.MinID("")
page.Limit = 20 // default
} }
minID := "" // Fetch serialized status replies response for input status.
minIDString := c.Query(MinIDKey) resp, errWithCode := m.processor.Fedi().StatusRepliesGet(
if minIDString != "" { c.Request.Context(),
minID = minIDString requestedUsername,
} requestedStatusID,
page,
resp, errWithCode := m.processor.Fedi().StatusRepliesGet(c.Request.Context(), requestedUsername, requestedStatusID, page, onlyOtherAccounts, c.Query("only_other_accounts") != "", minID) onlyOtherAccounts,
)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
@ -157,7 +163,8 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) {
b, err := json.Marshal(resp) b, err := json.Marshal(resp)
if err != nil { if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) errWithCode := gtserror.NewErrorInternalError(err)
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }

View File

@ -18,10 +18,10 @@
package users_test package users_test
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "io"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -31,6 +31,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -49,7 +50,7 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
// setup request // setup request
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx, _ := testrig.CreateGinTestContext(recorder, nil) ctx, _ := testrig.CreateGinTestContext(recorder, nil)
ctx.Request = httptest.NewRequest(http.MethodGet, targetStatus.URI+"/replies", nil) // the endpoint we're hitting ctx.Request = httptest.NewRequest(http.MethodGet, targetStatus.URI+"/replies?only_other_accounts=false", nil) // the endpoint we're hitting
ctx.Request.Header.Set("accept", "application/activity+json") ctx.Request.Header.Set("accept", "application/activity+json")
ctx.Request.Header.Set("Signature", signedRequest.SignatureHeader) ctx.Request.Header.Set("Signature", signedRequest.SignatureHeader)
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
@ -76,13 +77,26 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
// check response // check response
suite.EqualValues(http.StatusOK, recorder.Code) suite.EqualValues(http.StatusOK, recorder.Code)
// Read response body.
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"@context":"https://www.w3.org/ns/activitystreams","first":{"id":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?page=true","next":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?only_other_accounts=false\u0026page=true","partOf":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies","type":"CollectionPage"},"id":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies","type":"Collection"}`, string(b))
// should be a Collection // Indent JSON
// for readability.
b = indentJSON(b)
// Create JSON string of expected output.
expect := toJSON(map[string]any{
"@context": "https://www.w3.org/ns/activitystreams",
"type": "OrderedCollection",
"id": targetStatus.URI + "/replies?only_other_accounts=false",
"first": targetStatus.URI + "/replies?limit=20&only_other_accounts=false",
"totalItems": 1,
})
assert.Equal(suite.T(), expect, string(b))
m := make(map[string]interface{}) m := make(map[string]interface{})
err = json.Unmarshal(b, &m) err = json.Unmarshal(b, &m)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
@ -90,7 +104,7 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
t, err := streams.ToType(context.Background(), m) t, err := streams.ToType(context.Background(), m)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
_, ok := t.(vocab.ActivityStreamsCollection) _, ok := t.(vocab.ActivityStreamsOrderedCollection)
assert.True(suite.T(), ok) assert.True(suite.T(), ok)
} }
@ -131,14 +145,29 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
// check response // check response
suite.EqualValues(http.StatusOK, recorder.Code) suite.EqualValues(http.StatusOK, recorder.Code)
// Read response body.
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
assert.Equal(suite.T(), `{"@context":"https://www.w3.org/ns/activitystreams","id":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?page=true\u0026only_other_accounts=false","items":"http://localhost:8080/users/admin/statuses/01FF25D5Q0DH7CHD57CTRS6WK0","next":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?only_other_accounts=false\u0026page=true\u0026min_id=01FF25D5Q0DH7CHD57CTRS6WK0","partOf":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies","type":"CollectionPage"}`, string(b)) // Indent JSON
// for readability.
b = indentJSON(b)
// Create JSON string of expected output.
expect := toJSON(map[string]any{
"@context": "https://www.w3.org/ns/activitystreams",
"type": "OrderedCollectionPage",
"id": targetStatus.URI + "/replies?limit=20&only_other_accounts=false",
"partOf": targetStatus.URI + "/replies?only_other_accounts=false",
"next": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?limit=20&min_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false",
"prev": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?limit=20&max_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false",
"orderedItems": "http://localhost:8080/users/admin/statuses/01FF25D5Q0DH7CHD57CTRS6WK0",
"totalItems": 1,
})
assert.Equal(suite.T(), expect, string(b))
// should be a Collection
m := make(map[string]interface{}) m := make(map[string]interface{})
err = json.Unmarshal(b, &m) err = json.Unmarshal(b, &m)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
@ -146,10 +175,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
t, err := streams.ToType(context.Background(), m) t, err := streams.ToType(context.Background(), m)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
page, ok := t.(vocab.ActivityStreamsCollectionPage) page, ok := t.(vocab.ActivityStreamsOrderedCollectionPage)
assert.True(suite.T(), ok) assert.True(suite.T(), ok)
assert.Equal(suite.T(), page.GetActivityStreamsItems().Len(), 1) assert.Equal(suite.T(), page.GetActivityStreamsOrderedItems().Len(), 1)
} }
func (suite *RepliesGetTestSuite) TestGetRepliesLast() { func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
@ -162,7 +191,7 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
// setup request // setup request
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx, _ := testrig.CreateGinTestContext(recorder, nil) ctx, _ := testrig.CreateGinTestContext(recorder, nil)
ctx.Request = httptest.NewRequest(http.MethodGet, targetStatus.URI+"/replies?only_other_accounts=false&page=true&min_id=01FF25D5Q0DH7CHD57CTRS6WK0", nil) // the endpoint we're hitting ctx.Request = httptest.NewRequest(http.MethodGet, targetStatus.URI+"/replies?min_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false", nil)
ctx.Request.Header.Set("accept", "application/activity+json") ctx.Request.Header.Set("accept", "application/activity+json")
ctx.Request.Header.Set("Signature", signedRequest.SignatureHeader) ctx.Request.Header.Set("Signature", signedRequest.SignatureHeader)
ctx.Request.Header.Set("Date", signedRequest.DateHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader)
@ -189,15 +218,27 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
// check response // check response
suite.EqualValues(http.StatusOK, recorder.Code) suite.EqualValues(http.StatusOK, recorder.Code)
// Read response body.
result := recorder.Result() result := recorder.Result()
defer result.Body.Close() defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
fmt.Println(string(b)) // Indent JSON
assert.Equal(suite.T(), `{"@context":"https://www.w3.org/ns/activitystreams","id":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?page=true\u0026only_other_accounts=false\u0026min_id=01FF25D5Q0DH7CHD57CTRS6WK0","items":[],"next":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?only_other_accounts=false\u0026page=true","partOf":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies","type":"CollectionPage"}`, string(b)) // for readability.
b = indentJSON(b)
// Create JSON string of expected output.
expect := toJSON(map[string]any{
"@context": "https://www.w3.org/ns/activitystreams",
"type": "OrderedCollectionPage",
"id": targetStatus.URI + "/replies?min_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false",
"partOf": targetStatus.URI + "/replies?only_other_accounts=false",
"orderedItems": []any{}, // empty
"totalItems": 1,
})
assert.Equal(suite.T(), expect, string(b))
// should be a Collection
m := make(map[string]interface{}) m := make(map[string]interface{})
err = json.Unmarshal(b, &m) err = json.Unmarshal(b, &m)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
@ -205,12 +246,39 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
t, err := streams.ToType(context.Background(), m) t, err := streams.ToType(context.Background(), m)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
page, ok := t.(vocab.ActivityStreamsCollectionPage) page, ok := t.(vocab.ActivityStreamsOrderedCollectionPage)
assert.True(suite.T(), ok) assert.True(suite.T(), ok)
assert.Equal(suite.T(), page.GetActivityStreamsItems().Len(), 0) assert.Equal(suite.T(), page.GetActivityStreamsOrderedItems().Len(), 0)
} }
func TestRepliesGetTestSuite(t *testing.T) { func TestRepliesGetTestSuite(t *testing.T) {
suite.Run(t, new(RepliesGetTestSuite)) suite.Run(t, new(RepliesGetTestSuite))
} }
// toJSON will return indented JSON serialized form of 'a'.
func toJSON(a any) string {
v, ok := a.(vocab.Type)
if ok {
m, err := ap.Serialize(v)
if err != nil {
panic(err)
}
a = m
}
b, err := json.MarshalIndent(a, "", " ")
if err != nil {
panic(err)
}
return string(b)
}
// indentJSON will return indented JSON from raw provided JSON.
func indentJSON(b []byte) []byte {
var dst bytes.Buffer
err := json.Indent(&dst, b, "", " ")
if err != nil {
panic(err)
}
return dst.Bytes()
}

View File

@ -41,6 +41,10 @@ const (
SinceIDKey = "since_id" SinceIDKey = "since_id"
MinIDKey = "min_id" MinIDKey = "min_id"
/* AP endpoint keys */
OnlyOtherAccountsKey = "only_other_accounts"
/* Search keys */ /* Search keys */
SearchExcludeUnreviewedKey = "exclude_unreviewed" SearchExcludeUnreviewedKey = "exclude_unreviewed"
@ -66,20 +70,6 @@ const (
DomainPermissionImportKey = "import" DomainPermissionImportKey = "import"
) )
// parseError returns gtserror.WithCode set to 400 Bad Request, to indicate
// to the caller that a key was set to a value that could not be parsed.
func parseError(key string, value, defaultValue any, err error) gtserror.WithCode {
err = fmt.Errorf("error parsing key %s with value %s as %T: %w", key, value, defaultValue, err)
return gtserror.NewErrorBadRequest(err, err.Error())
}
// requiredError returns gtserror.WithCode set to 400 Bad Request, to indicate
// to the caller a required key value was not provided, or was empty.
func requiredError(key string) gtserror.WithCode {
err := fmt.Errorf("required key %s was not set or had empty value", key)
return gtserror.NewErrorBadRequest(err, err.Error())
}
/* /*
Parse functions for *OPTIONAL* parameters with default values. Parse functions for *OPTIONAL* parameters with default values.
*/ */
@ -129,6 +119,10 @@ func ParseDomainPermissionImport(value string, defaultValue bool) (bool, gtserro
return parseBool(value, defaultValue, DomainPermissionImportKey) return parseBool(value, defaultValue, DomainPermissionImportKey)
} }
func ParseOnlyOtherAccounts(value string, defaultValue bool) (bool, gtserror.WithCode) {
return parseBool(value, defaultValue, OnlyOtherAccountsKey)
}
/* /*
Parse functions for *REQUIRED* parameters. Parse functions for *REQUIRED* parameters.
*/ */
@ -248,3 +242,17 @@ func parseInt(value string, defaultValue int, max int, min int, key string) (int
return i, nil return i, nil
} }
// parseError returns gtserror.WithCode set to 400 Bad Request, to indicate
// to the caller that a key was set to a value that could not be parsed.
func parseError(key string, value, defaultValue any, err error) gtserror.WithCode {
err = fmt.Errorf("error parsing key %s with value %s as %T: %w", key, value, defaultValue, err)
return gtserror.NewErrorBadRequest(err, err.Error())
}
// requiredError returns gtserror.WithCode set to 400 Bad Request, to indicate
// to the caller a required key value was not provided, or was empty.
func requiredError(key string) gtserror.WithCode {
err := fmt.Errorf("required key %s was not set or had empty value", key)
return gtserror.NewErrorBadRequest(err, err.Error())
}

View File

@ -18,7 +18,6 @@
package bundb package bundb
import ( import (
"container/list"
"context" "context"
"errors" "errors"
"time" "time"
@ -515,16 +514,7 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([
return s.GetStatusesByIDs(ctx, statusIDs) return s.GetStatusesByIDs(ctx, statusIDs)
} }
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) { func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) {
if onlyDirect {
// Only want the direct parent, no further than first level
parent, err := s.GetStatusByID(ctx, status.InReplyToID)
if err != nil {
return nil, err
}
return []*gtsmodel.Status{parent}, nil
}
var parents []*gtsmodel.Status var parents []*gtsmodel.Status
for id := status.InReplyToID; id != ""; { for id := status.InReplyToID; id != ""; {
@ -533,7 +523,7 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status
return nil, err return nil, err
} }
// Append parent to slice // Append parent status to slice
parents = append(parents, parent) parents = append(parents, parent)
// Set the next parent ID // Set the next parent ID
@ -543,67 +533,33 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status
return parents, nil return parents, nil
} }
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) { func (s *statusDB) GetStatusChildren(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) {
foundStatuses := &list.List{} // Get all replies for the currently set status.
foundStatuses.PushFront(status) replies, err := s.GetStatusReplies(ctx, statusID)
s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID) if err != nil {
return nil, err
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
// only append children, not the overall parent status
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
log.Panic(ctx, "found status could not be asserted to *gtsmodel.Status")
} }
if entry.ID != status.ID { // Make estimated preallocation based on direct replies.
children = append(children, entry) children := make([]*gtsmodel.Status, 0, len(replies)*2)
for _, status := range replies {
// Append status to children.
children = append(children, status)
// Further, recursively get all children for this reply.
grandChildren, err := s.GetStatusChildren(ctx, status.ID)
if err != nil {
return nil, err
} }
// Append all sub children after status.
children = append(children, grandChildren...)
} }
return children, nil return children, nil
} }
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
childIDs, err := s.getStatusReplyIDs(ctx, status.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
log.Errorf(ctx, "error getting status %s children: %v", status.ID, err)
return
}
for _, id := range childIDs {
if id <= minID {
continue
}
// Fetch child with ID from database
child, err := s.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting child status %q: %v", id, err)
continue
}
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
log.Panic(ctx, "found status could not be asserted to *gtsmodel.Status")
}
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
}
}
// if we're not only looking for direct children of status, then do the same children-finding
// operation for the found child status too.
if !onlyDirect {
s.statusChildren(ctx, child, foundStatuses, false, minID)
}
}
}
func (s *statusDB) GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { func (s *statusDB) GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) {
statusIDs, err := s.getStatusReplyIDs(ctx, statusID) statusIDs, err := s.getStatusReplyIDs(ctx, statusID)
if err != nil { if err != nil {

View File

@ -163,9 +163,21 @@ func (suite *StatusTestSuite) TestGetStatusTwice() {
suite.Less(duration2, duration1) suite.Less(duration2, duration1)
} }
func (suite *StatusTestSuite) TestGetStatusReplies() {
targetStatus := suite.testStatuses["local_account_1_status_1"]
children, err := suite.db.GetStatusReplies(context.Background(), targetStatus.ID)
suite.NoError(err)
suite.Len(children, 2)
for _, c := range children {
suite.Equal(targetStatus.URI, c.InReplyToURI)
suite.Equal(targetStatus.AccountID, c.InReplyToAccountID)
suite.Equal(targetStatus.ID, c.InReplyToID)
}
}
func (suite *StatusTestSuite) TestGetStatusChildren() { func (suite *StatusTestSuite) TestGetStatusChildren() {
targetStatus := suite.testStatuses["local_account_1_status_1"] targetStatus := suite.testStatuses["local_account_1_status_1"]
children, err := suite.db.GetStatusChildren(context.Background(), targetStatus, true, "") children, err := suite.db.GetStatusChildren(context.Background(), targetStatus.ID)
suite.NoError(err) suite.NoError(err)
suite.Len(children, 2) suite.Len(children, 2)
for _, c := range children { for _, c := range children {

View File

@ -18,6 +18,7 @@
package bundb package bundb
import ( import (
"slices"
"strings" "strings"
"github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/cache"
@ -99,7 +100,7 @@ func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page
// order. Depending on the paging requested // order. Depending on the paging requested
// this may be an unexpected order. // this may be an unexpected order.
if page.GetOrder().Ascending() { if page.GetOrder().Ascending() {
ids = paging.Reverse(ids) slices.Reverse(ids)
} }
// Page the resulting IDs. // Page the resulting IDs.

View File

@ -55,7 +55,7 @@ type Status interface {
// GetStatusesUsingEmoji fetches all status models using emoji with given ID stored in their 'emojis' column. // GetStatusesUsingEmoji fetches all status models using emoji with given ID stored in their 'emojis' column.
GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Status, error) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Status, error)
// GetStatusReplies returns the *direct* (i.e. in_reply_to_id column) replies to this status ID. // GetStatusReplies returns the *direct* (i.e. in_reply_to_id column) replies to this status ID, ordered DESC by ID.
GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error)
// CountStatusReplies returns the number of stored *direct* (i.e. in_reply_to_id column) replies to this status ID. // CountStatusReplies returns the number of stored *direct* (i.e. in_reply_to_id column) replies to this status ID.
@ -71,14 +71,10 @@ type Status interface {
IsStatusBoostedBy(ctx context.Context, statusID string, accountID string) (bool, error) IsStatusBoostedBy(ctx context.Context, statusID string, accountID string) (bool, error)
// GetStatusParents gets the parent statuses of a given status. // GetStatusParents gets the parent statuses of a given status.
// GetStatusParents(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error)
// If onlyDirect is true, only the immediate parent will be returned.
GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
// GetStatusChildren gets the child statuses of a given status. // GetStatusChildren gets the child statuses of a given status.
// GetStatusChildren(ctx context.Context, statusID string) ([]*gtsmodel.Status, error)
// If onlyDirect is true, only the immediate children will be returned.
GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)

View File

@ -131,3 +131,20 @@ func (b Boundary) Find(in []string) int {
} }
return -1 return -1
} }
// Boundary_FindFunc is functionally equivalent to Boundary{}.Find() but for an arbitrary type with ID.
// Note: this is not a Boundary{} method as Go generics are not supported in method receiver functions.
func Boundary_FindFunc[T any](b Boundary, in []T, get func(T) string) int { //nolint:revive
if get == nil {
panic("nil function")
}
if b.Value == "" {
return -1
}
for i := range in {
if get(in[i]) == b.Value {
return i
}
}
return -1
}

View File

@ -19,9 +19,8 @@ package paging
import ( import (
"net/url" "net/url"
"slices"
"strconv" "strconv"
"golang.org/x/exp/slices"
) )
type Page struct { type Page struct {
@ -117,7 +116,7 @@ func (p *Page) Page(in []string) []string {
// Output slice must // Output slice must
// ALWAYS be descending. // ALWAYS be descending.
in = Reverse(in) slices.Reverse(in)
} }
} else { } else {
// Default sort is descending, // Default sort is descending,
@ -143,6 +142,66 @@ func (p *Page) Page(in []string) []string {
return in return in
} }
// Page_PageFunc is functionally equivalent to Page{}.Page(), but for an arbitrary type with ID.
// Note: this is not a Page{} method as Go generics are not supported in method receiver functions.
func Page_PageFunc[WithID any](p *Page, in []WithID, get func(WithID) string) []WithID { //nolint:revive
if p == nil {
// no paging.
return in
}
if p.order().Ascending() {
// Sort type is ascending, input
// data is assumed to be ascending.
if minIdx := Boundary_FindFunc(p.Min, in, get); minIdx != -1 {
// Reslice skipping up to min.
in = in[minIdx+1:]
}
if maxIdx := Boundary_FindFunc(p.Max, in, get); maxIdx != -1 {
// Reslice stripping past max.
in = in[:maxIdx]
}
if p.Limit > 0 && p.Limit < len(in) {
// Reslice input to limit.
in = in[:p.Limit]
}
if len(in) > 1 {
// Clone input before
// any modifications.
in = slices.Clone(in)
// Output slice must
// ALWAYS be descending.
slices.Reverse(in)
}
} else {
// Default sort is descending,
// catching all cases when NOT
// ascending (even zero value).
if maxIdx := Boundary_FindFunc(p.Max, in, get); maxIdx != -1 {
// Reslice skipping up to max.
in = in[maxIdx+1:]
}
if minIdx := Boundary_FindFunc(p.Min, in, get); minIdx != -1 {
// Reslice stripping past min.
in = in[:minIdx]
}
if p.Limit > 0 && p.Limit < len(in) {
// Reslice input to limit.
in = in[:p.Limit]
}
}
return in
}
// Next creates a new instance for the next returnable page, using // Next creates a new instance for the next returnable page, using
// given max value. This preserves original limit and max key name. // given max value. This preserves original limit and max key name.
func (p *Page) Next(lo, hi string) *Page { func (p *Page) Next(lo, hi string) *Page {
@ -225,21 +284,24 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url.
if queryParams == nil { if queryParams == nil {
// Allocate new query parameters. // Allocate new query parameters.
queryParams = make(url.Values) queryParams = make(url.Values)
} else {
// Before edit clone existing params.
queryParams = cloneQuery(queryParams)
} }
if p.Min.Value != "" { if p.Min.Value != "" {
// A page-minimum query parameter is available. // A page-minimum query parameter is available.
queryParams.Add(p.Min.Name, p.Min.Value) queryParams.Set(p.Min.Name, p.Min.Value)
} }
if p.Max.Value != "" { if p.Max.Value != "" {
// A page-maximum query parameter is available. // A page-maximum query parameter is available.
queryParams.Add(p.Max.Name, p.Max.Value) queryParams.Set(p.Max.Name, p.Max.Value)
} }
if p.Limit > 0 { if p.Limit > 0 {
// A page limit query parameter is available. // A page limit query parameter is available.
queryParams.Add("limit", strconv.Itoa(p.Limit)) queryParams.Set("limit", strconv.Itoa(p.Limit))
} }
// Build URL string. // Build URL string.
@ -250,3 +312,12 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url.
RawQuery: queryParams.Encode(), RawQuery: queryParams.Encode(),
} }
} }
// cloneQuery clones input map of url values.
func cloneQuery(src url.Values) url.Values {
dst := make(url.Values, len(src))
for k, vs := range src {
dst[k] = slices.Clone(vs)
}
return dst
}

View File

@ -19,12 +19,12 @@ package paging_test
import ( import (
"math/rand" "math/rand"
"slices"
"testing" "testing"
"time" "time"
"github.com/oklog/ulid" "github.com/oklog/ulid"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"golang.org/x/exp/slices"
) )
// random reader according to current-time source seed. // random reader according to current-time source seed.
@ -77,9 +77,7 @@ func TestPage(t *testing.T) {
var cases = []Case{ var cases = []Case{
CreateCase("minID and maxID set", func(ids []string) ([]string, *paging.Page, []string) { CreateCase("minID and maxID set", func(ids []string) ([]string, *paging.Page, []string) {
// Ensure input slice sorted ascending for min_id // Ensure input slice sorted ascending for min_id
slices.SortFunc(ids, func(a, b string) bool { slices.SortFunc(ids, ascending)
return a > b // i.e. largest at lowest idx
})
// Select random indices in slice. // Select random indices in slice.
minIdx := randRd.Intn(len(ids)) minIdx := randRd.Intn(len(ids))
@ -93,7 +91,7 @@ var cases = []Case{
expect := slices.Clone(ids) expect := slices.Clone(ids)
expect = cutLower(expect, minID) expect = cutLower(expect, minID)
expect = cutUpper(expect, maxID) expect = cutUpper(expect, maxID)
expect = paging.Reverse(expect) slices.Reverse(expect)
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
@ -103,9 +101,7 @@ var cases = []Case{
}), }),
CreateCase("minID, maxID and limit set", func(ids []string) ([]string, *paging.Page, []string) { CreateCase("minID, maxID and limit set", func(ids []string) ([]string, *paging.Page, []string) {
// Ensure input slice sorted ascending for min_id // Ensure input slice sorted ascending for min_id
slices.SortFunc(ids, func(a, b string) bool { slices.SortFunc(ids, ascending)
return a > b // i.e. largest at lowest idx
})
// Select random parameters in slice. // Select random parameters in slice.
minIdx := randRd.Intn(len(ids)) minIdx := randRd.Intn(len(ids))
@ -120,7 +116,7 @@ var cases = []Case{
expect := slices.Clone(ids) expect := slices.Clone(ids)
expect = cutLower(expect, minID) expect = cutLower(expect, minID)
expect = cutUpper(expect, maxID) expect = cutUpper(expect, maxID)
expect = paging.Reverse(expect) slices.Reverse(expect)
// Now limit the slice. // Now limit the slice.
if limit < len(expect) { if limit < len(expect) {
@ -136,9 +132,7 @@ var cases = []Case{
}), }),
CreateCase("minID, maxID and too-large limit set", func(ids []string) ([]string, *paging.Page, []string) { CreateCase("minID, maxID and too-large limit set", func(ids []string) ([]string, *paging.Page, []string) {
// Ensure input slice sorted ascending for min_id // Ensure input slice sorted ascending for min_id
slices.SortFunc(ids, func(a, b string) bool { slices.SortFunc(ids, ascending)
return a > b // i.e. largest at lowest idx
})
// Select random parameters in slice. // Select random parameters in slice.
minIdx := randRd.Intn(len(ids)) minIdx := randRd.Intn(len(ids))
@ -152,7 +146,7 @@ var cases = []Case{
expect := slices.Clone(ids) expect := slices.Clone(ids)
expect = cutLower(expect, minID) expect = cutLower(expect, minID)
expect = cutUpper(expect, maxID) expect = cutUpper(expect, maxID)
expect = paging.Reverse(expect) slices.Reverse(expect)
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
@ -163,9 +157,7 @@ var cases = []Case{
}), }),
CreateCase("sinceID and maxID set", func(ids []string) ([]string, *paging.Page, []string) { CreateCase("sinceID and maxID set", func(ids []string) ([]string, *paging.Page, []string) {
// Ensure input slice sorted descending for since_id // Ensure input slice sorted descending for since_id
slices.SortFunc(ids, func(a, b string) bool { slices.SortFunc(ids, descending)
return a < b // i.e. smallest at lowest idx
})
// Select random indices in slice. // Select random indices in slice.
sinceIdx := randRd.Intn(len(ids)) sinceIdx := randRd.Intn(len(ids))
@ -188,9 +180,7 @@ var cases = []Case{
}), }),
CreateCase("maxID set", func(ids []string) ([]string, *paging.Page, []string) { CreateCase("maxID set", func(ids []string) ([]string, *paging.Page, []string) {
// Ensure input slice sorted descending for max_id // Ensure input slice sorted descending for max_id
slices.SortFunc(ids, func(a, b string) bool { slices.SortFunc(ids, descending)
return a < b // i.e. smallest at lowest idx
})
// Select random indices in slice. // Select random indices in slice.
maxIdx := randRd.Intn(len(ids)) maxIdx := randRd.Intn(len(ids))
@ -209,9 +199,7 @@ var cases = []Case{
}), }),
CreateCase("sinceID set", func(ids []string) ([]string, *paging.Page, []string) { CreateCase("sinceID set", func(ids []string) ([]string, *paging.Page, []string) {
// Ensure input slice sorted descending for since_id // Ensure input slice sorted descending for since_id
slices.SortFunc(ids, func(a, b string) bool { slices.SortFunc(ids, descending)
return a < b
})
// Select random indices in slice. // Select random indices in slice.
sinceIdx := randRd.Intn(len(ids)) sinceIdx := randRd.Intn(len(ids))
@ -230,9 +218,7 @@ var cases = []Case{
}), }),
CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) { CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) {
// Ensure input slice sorted ascending for min_id // Ensure input slice sorted ascending for min_id
slices.SortFunc(ids, func(a, b string) bool { slices.SortFunc(ids, ascending)
return a > b // i.e. largest at lowest idx
})
// Select random indices in slice. // Select random indices in slice.
minIdx := randRd.Intn(len(ids)) minIdx := randRd.Intn(len(ids))
@ -243,7 +229,7 @@ var cases = []Case{
// Create expected output. // Create expected output.
expect := slices.Clone(ids) expect := slices.Clone(ids)
expect = cutLower(expect, minID) expect = cutLower(expect, minID)
expect = paging.Reverse(expect) slices.Reverse(expect)
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
@ -296,3 +282,21 @@ func generateSlice(len int) []string {
} }
return in return in
} }
func ascending(a, b string) int {
if a > b {
return 1
} else if a < b {
return -1
}
return 0
}
func descending(a, b string) int {
if a < b {
return 1
} else if a > b {
return -1
}
return 0
}

View File

@ -1,43 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package paging
// Reverse will reverse the given input slice.
func Reverse(in []string) []string {
var (
// Start at front.
i = 0
// Start at back.
j = len(in) - 1
)
for i < j {
// Swap i,j index values in slice.
in[i], in[j] = in[j], in[i]
// incr + decr,
// looping until
// they meet in
// the middle.
i++
j--
}
return in
}

View File

@ -47,8 +47,15 @@ func (p *Processor) InboxPost(ctx context.Context, w http.ResponseWriter, r *htt
// OutboxGet returns the activitypub representation of a local user's outbox. // OutboxGet returns the activitypub representation of a local user's outbox.
// This contains links to PUBLIC posts made by this user. // This contains links to PUBLIC posts made by this user.
func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, page bool, maxID string, minID string) (interface{}, gtserror.WithCode) { func (p *Processor) OutboxGet(
requestedAccount, _, errWithCode := p.authenticate(ctx, requestedUsername) ctx context.Context,
requestedUser string,
page bool,
maxID string,
minID string,
) (interface{}, gtserror.WithCode) {
// Authenticate the incoming request, getting related user accounts.
_, receiver, errWithCode := p.authenticate(ctx, requestedUser)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -70,7 +77,7 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag
"last": "https://example.org/users/whatever/outbox?min_id=0&page=true" "last": "https://example.org/users/whatever/outbox?min_id=0&page=true"
} }
*/ */
collection, err := p.converter.OutboxToASCollection(ctx, requestedAccount.OutboxURI) collection, err := p.converter.OutboxToASCollection(ctx, receiver.OutboxURI)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -85,15 +92,16 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag
// scenario 2 -- get the requested page // scenario 2 -- get the requested page
// limit pages to 30 entries per page // limit pages to 30 entries per page
publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true) publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, receiver.ID, 30, true, true, maxID, minID, false, true)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
outboxPage, err := p.converter.StatusesToASOutboxPage(ctx, requestedAccount.OutboxURI, maxID, minID, publicStatuses) outboxPage, err := p.converter.StatusesToASOutboxPage(ctx, receiver.OutboxURI, maxID, minID, publicStatuses)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
data, err = ap.Serialize(outboxPage) data, err = ap.Serialize(outboxPage)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -104,21 +112,22 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag
// FollowersGet handles the getting of a fedi/activitypub representation of a user/account's followers, performing appropriate // FollowersGet handles the getting of a fedi/activitypub representation of a user/account's followers, performing appropriate
// authentication before returning a JSON serializable interface to the caller. // authentication before returning a JSON serializable interface to the caller.
func (p *Processor) FollowersGet(ctx context.Context, requestedUsername string, page *paging.Page) (interface{}, gtserror.WithCode) { func (p *Processor) FollowersGet(ctx context.Context, requestedUser string, page *paging.Page) (interface{}, gtserror.WithCode) {
requestedAccount, _, errWithCode := p.authenticate(ctx, requestedUsername) // Authenticate the incoming request, getting related user accounts.
_, receiver, errWithCode := p.authenticate(ctx, requestedUser)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
// Parse the collection ID object from account's followers URI. // Parse the collection ID object from account's followers URI.
collectionID, err := url.Parse(requestedAccount.FollowersURI) collectionID, err := url.Parse(receiver.FollowersURI)
if err != nil { if err != nil {
err := gtserror.Newf("error parsing account followers uri %s: %w", requestedAccount.FollowersURI, err) err := gtserror.Newf("error parsing account followers uri %s: %w", receiver.FollowersURI, err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// Calculate total number of followers available for account. // Calculate total number of followers available for account.
total, err := p.state.DB.CountAccountFollowers(ctx, requestedAccount.ID) total, err := p.state.DB.CountAccountFollowers(ctx, receiver.ID)
if err != nil { if err != nil {
err := gtserror.Newf("error counting followers: %w", err) err := gtserror.Newf("error counting followers: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -126,30 +135,36 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUsername string,
var obj vocab.Type var obj vocab.Type
// Start building AS collection params. // Start the AS collection params.
var params ap.CollectionParams var params ap.CollectionParams
params.ID = collectionID params.ID = collectionID
params.Total = total params.Total = total
if page == nil { if page == nil {
// i.e. paging disabled, the simplest case. // i.e. paging disabled, return collection
// // that links to first page (i.e. path below).
// Just build collection object from params. params.Query = make(url.Values, 1)
params.Query.Set("limit", "40") // enables paging
obj = ap.NewASOrderedCollection(params) obj = ap.NewASOrderedCollection(params)
} else { } else {
// i.e. paging enabled // i.e. paging enabled
// Get the request page of full follower objects with attached accounts. // Get the request page of full follower objects with attached accounts.
followers, err := p.state.DB.GetAccountFollowers(ctx, requestedAccount.ID, page) followers, err := p.state.DB.GetAccountFollowers(ctx, receiver.ID, page)
if err != nil { if err != nil {
err := gtserror.Newf("error getting followers: %w", err) err := gtserror.Newf("error getting followers: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// page ID values.
var lo, hi string
if len(followers) > 0 {
// Get the lowest and highest // Get the lowest and highest
// ID values, used for paging. // ID values, used for paging.
lo := followers[len(followers)-1].ID lo = followers[len(followers)-1].ID
hi := followers[0].ID hi = followers[0].ID
}
// Start building AS collection page params. // Start building AS collection page params.
var pageParams ap.CollectionPageParams var pageParams ap.CollectionPageParams
@ -196,21 +211,22 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUsername string,
// FollowingGet handles the getting of a fedi/activitypub representation of a user/account's following, performing appropriate // FollowingGet handles the getting of a fedi/activitypub representation of a user/account's following, performing appropriate
// authentication before returning a JSON serializable interface to the caller. // authentication before returning a JSON serializable interface to the caller.
func (p *Processor) FollowingGet(ctx context.Context, requestedUsername string, page *paging.Page) (interface{}, gtserror.WithCode) { func (p *Processor) FollowingGet(ctx context.Context, requestedUser string, page *paging.Page) (interface{}, gtserror.WithCode) {
requestedAccount, _, errWithCode := p.authenticate(ctx, requestedUsername) // Authenticate the incoming request, getting related user accounts.
_, receiver, errWithCode := p.authenticate(ctx, requestedUser)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
// Parse the collection ID object from account's following URI. // Parse collection ID from account's following URI.
collectionID, err := url.Parse(requestedAccount.FollowingURI) collectionID, err := url.Parse(receiver.FollowingURI)
if err != nil { if err != nil {
err := gtserror.Newf("error parsing account following uri %s: %w", requestedAccount.FollowingURI, err) err := gtserror.Newf("error parsing account following uri %s: %w", receiver.FollowingURI, err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// Calculate total number of following available for account. // Calculate total number of following available for account.
total, err := p.state.DB.CountAccountFollows(ctx, requestedAccount.ID) total, err := p.state.DB.CountAccountFollows(ctx, receiver.ID)
if err != nil { if err != nil {
err := gtserror.Newf("error counting follows: %w", err) err := gtserror.Newf("error counting follows: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -218,32 +234,38 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUsername string,
var obj vocab.Type var obj vocab.Type
// Start building AS collection params. // Start AS collection params.
var params ap.CollectionParams var params ap.CollectionParams
params.ID = collectionID params.ID = collectionID
params.Total = total params.Total = total
if page == nil { if page == nil {
// i.e. paging disabled, the simplest case. // i.e. paging disabled, return collection
// // that links to first page (i.e. path below).
// Just build collection object from params. params.Query = make(url.Values, 1)
params.Query.Set("limit", "40") // enables paging
obj = ap.NewASOrderedCollection(params) obj = ap.NewASOrderedCollection(params)
} else { } else {
// i.e. paging enabled // i.e. paging enabled
// Get the request page of full follower objects with attached accounts. // Get the request page of full follower objects with attached accounts.
follows, err := p.state.DB.GetAccountFollows(ctx, requestedAccount.ID, page) follows, err := p.state.DB.GetAccountFollows(ctx, receiver.ID, page)
if err != nil { if err != nil {
err := gtserror.Newf("error getting follows: %w", err) err := gtserror.Newf("error getting follows: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// page ID values.
var lo, hi string
if len(follows) > 0 {
// Get the lowest and highest // Get the lowest and highest
// ID values, used for paging. // ID values, used for paging.
lo := follows[len(follows)-1].ID lo = follows[len(follows)-1].ID
hi := follows[0].ID hi = follows[0].ID
}
// Start building AS collection page params. // Start AS collection page params.
var pageParams ap.CollectionPageParams var pageParams ap.CollectionPageParams
pageParams.CollectionParams = params pageParams.CollectionParams = params
@ -288,20 +310,21 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUsername string,
// FeaturedCollectionGet returns an ordered collection of the requested username's Pinned posts. // FeaturedCollectionGet returns an ordered collection of the requested username's Pinned posts.
// The returned collection have an `items` property which contains an ordered list of status URIs. // The returned collection have an `items` property which contains an ordered list of status URIs.
func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUsername string) (interface{}, gtserror.WithCode) { func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUser string) (interface{}, gtserror.WithCode) {
requestedAccount, _, errWithCode := p.authenticate(ctx, requestedUsername) // Authenticate the incoming request, getting related user accounts.
_, receiver, errWithCode := p.authenticate(ctx, requestedUser)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, requestedAccount.ID) statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, receiver.ID)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
} }
collection, err := p.converter.StatusesToASFeaturedCollection(ctx, requestedAccount.FeaturedCollectionURI, statuses) collection, err := p.converter.StatusesToASFeaturedCollection(ctx, receiver.FeaturedCollectionURI, statuses)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View File

@ -20,7 +20,6 @@ package fedi
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@ -28,17 +27,17 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
) )
func (p *Processor) authenticate(ctx context.Context, requestedUsername string) ( func (p *Processor) authenticate(ctx context.Context, requestedUser string) (
*gtsmodel.Account, // requestedAccount *gtsmodel.Account, // requester: i.e. user making the request
*gtsmodel.Account, // requestingAccount *gtsmodel.Account, // receiver: i.e. the receiving inbox user
gtserror.WithCode, gtserror.WithCode,
) { ) {
// Get LOCAL account with the requested username. // First get the requested (receiving) LOCAL account with username from database.
requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "") receiver, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUser, "")
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
// Real db error. // Real db error.
err = gtserror.Newf("db error getting account %s: %w", requestedUsername, err) err = gtserror.Newf("db error getting account %s: %w", requestedUser, err)
return nil, nil, gtserror.NewErrorInternalError(err) return nil, nil, gtserror.NewErrorInternalError(err)
} }
@ -46,41 +45,43 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string)
return nil, nil, gtserror.NewErrorNotFound(err) return nil, nil, gtserror.NewErrorNotFound(err)
} }
var requester *gtsmodel.Account
// Ensure request signed, and use signature URI to // Ensure request signed, and use signature URI to
// get requesting account, dereferencing if necessary. // get requesting account, dereferencing if necessary.
pubKeyAuth, errWithCode := p.federator.AuthenticateFederatedRequest(ctx, requestedUsername) pubKeyAuth, errWithCode := p.federator.AuthenticateFederatedRequest(ctx, requestedUser)
if errWithCode != nil { if errWithCode != nil {
return nil, nil, errWithCode return nil, nil, errWithCode
} }
requestingAccount, _, err := p.federator.GetAccountByURI( if requester = pubKeyAuth.Owner; requester == nil {
requester, _, err = p.federator.GetAccountByURI(
gtscontext.SetFastFail(ctx), gtscontext.SetFastFail(ctx),
requestedUsername, requestedUser,
pubKeyAuth.OwnerURI, pubKeyAuth.OwnerURI,
) )
if err != nil { if err != nil {
err = gtserror.Newf("error getting account %s: %w", pubKeyAuth.OwnerURI, err) err = gtserror.Newf("error getting account %s: %w", pubKeyAuth.OwnerURI, err)
return nil, nil, gtserror.NewErrorUnauthorized(err) return nil, nil, gtserror.NewErrorUnauthorized(err)
} }
}
if !requestingAccount.SuspendedAt.IsZero() { if !requester.SuspendedAt.IsZero() {
// Account was marked as suspended by a // Account was marked as suspended by a
// local admin action. Stop request early. // local admin action. Stop request early.
err = fmt.Errorf("account %s marked as suspended", requestingAccount.ID) const text = "requesting account is suspended"
return nil, nil, gtserror.NewErrorForbidden(err) return nil, nil, gtserror.NewErrorForbidden(errors.New(text))
} }
// Ensure no block exists between requester + requested. // Ensure no block exists between requester + requested.
blocked, err := p.state.DB.IsEitherBlocked(ctx, requestedAccount.ID, requestingAccount.ID) blocked, err := p.state.DB.IsEitherBlocked(ctx, receiver.ID, requester.ID)
if err != nil { if err != nil {
err = gtserror.Newf("db error getting checking block: %w", err) err = gtserror.Newf("db error getting checking block: %w", err)
return nil, nil, gtserror.NewErrorInternalError(err) return nil, nil, gtserror.NewErrorInternalError(err)
} } else if blocked {
err = gtserror.Newf("block exists between accounts %s and %s", requester.ID, receiver.ID)
if blocked {
err = fmt.Errorf("block exists between accounts %s and %s", requestedAccount.ID, requestingAccount.ID)
return nil, nil, gtserror.NewErrorForbidden(err) return nil, nil, gtserror.NewErrorForbidden(err)
} }
return requestedAccount, requestingAccount, nil return requester, receiver, nil
} }

View File

@ -19,21 +19,33 @@ package fedi
import ( import (
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/processing/common"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/internal/visibility"
) )
type Processor struct { type Processor struct {
// embed common logic
c *common.Processor
state *state.State state *state.State
federator *federation.Federator federator *federation.Federator
converter *typeutils.Converter converter *typeutils.Converter
filter *visibility.Filter filter *visibility.Filter
} }
// New returns a new fedi processor. // New returns a
func New(state *state.State, converter *typeutils.Converter, federator *federation.Federator, filter *visibility.Filter) Processor { // new fedi processor.
func New(
state *state.State,
common *common.Processor,
converter *typeutils.Converter,
federator *federation.Federator,
filter *visibility.Filter,
) Processor {
return Processor{ return Processor{
c: common,
state: state, state: state,
federator: federator, federator: federator,
converter: converter, converter: converter,

View File

@ -19,161 +19,192 @@ package fedi
import ( import (
"context" "context"
"fmt" "errors"
"net/url" "net/url"
"slices"
"strconv"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging"
) )
// StatusGet handles the getting of a fedi/activitypub representation of a local status. // StatusGet handles the getting of a fedi/activitypub representation of a local status.
// It performs appropriate authentication before returning a JSON serializable interface. // It performs appropriate authentication before returning a JSON serializable interface.
func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, requestedStatusID string) (interface{}, gtserror.WithCode) { func (p *Processor) StatusGet(ctx context.Context, requestedUser string, statusID string) (interface{}, gtserror.WithCode) {
// Authenticate using http signature. // Authenticate using http signature.
requestedAccount, requestingAccount, errWithCode := p.authenticate(ctx, requestedUsername) // Authenticate the incoming request, getting related user accounts.
requester, receiver, errWithCode := p.authenticate(ctx, requestedUser)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID) status, err := p.state.DB.GetStatusByID(ctx, statusID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
} }
if status.AccountID != requestedAccount.ID { if status.AccountID != receiver.ID {
err := fmt.Errorf("status with id %s does not belong to account with id %s", status.ID, requestedAccount.ID) const text = "status does not belong to receiving account"
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(errors.New(text))
} }
visible, err := p.filter.StatusVisible(ctx, requestingAccount, status) visible, err := p.filter.StatusVisible(ctx, requester, status)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
if !visible { if !visible {
err := fmt.Errorf("status with id %s not visible to user with id %s", status.ID, requestingAccount.ID) const text = "status not vising to requesting account"
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(errors.New(text))
} }
statusable, err := p.converter.StatusToAS(ctx, status) statusable, err := p.converter.StatusToAS(ctx, status)
if err != nil { if err != nil {
err := gtserror.Newf("error converting status: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
data, err := ap.Serialize(statusable) data, err := ap.Serialize(statusable)
if err != nil { if err != nil {
err := gtserror.Newf("error serializing status: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
return data, nil return data, nil
} }
// GetStatus handles the getting of a fedi/activitypub representation of replies to a status, performing appropriate // GetStatus handles the getting of a fedi/activitypub representation of replies to a status,
// authentication before returning a JSON serializable interface to the caller. // performing appropriate authentication before returning a JSON serializable interface to the caller.
func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername string, requestedStatusID string, page bool, onlyOtherAccounts bool, onlyOtherAccountsSet bool, minID string) (interface{}, gtserror.WithCode) { func (p *Processor) StatusRepliesGet(
requestedAccount, requestingAccount, errWithCode := p.authenticate(ctx, requestedUsername) ctx context.Context,
requestedUser string,
statusID string,
page *paging.Page,
onlyOtherAccounts bool,
) (interface{}, gtserror.WithCode) {
// Authenticate the incoming request, getting related user accounts.
requester, receiver, errWithCode := p.authenticate(ctx, requestedUser)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID) // Get target status and ensure visible to requester.
if err != nil { status, errWithCode := p.c.GetVisibleTargetStatus(ctx,
return nil, gtserror.NewErrorNotFound(err) requester,
} statusID,
)
if status.AccountID != requestedAccount.ID { if errWithCode != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s does not belong to account with id %s", status.ID, requestedAccount.ID)) return nil, errWithCode
} }
visible, err := p.filter.StatusVisible(ctx, requestedAccount, status) // Ensure status is by receiving account.
if err != nil { if status.AccountID != receiver.ID {
return nil, gtserror.NewErrorInternalError(err) const text = "status does not belong to receiving account"
} return nil, gtserror.NewErrorNotFound(errors.New(text))
if !visible { }
return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s not visible to user with id %s", status.ID, requestingAccount.ID))
} // Parse replies collection ID from status' URI with onlyOtherAccounts param.
onlyOtherAccStr := "only_other_accounts=" + strconv.FormatBool(onlyOtherAccounts)
var data map[string]interface{} collectionID, err := url.Parse(status.URI + "/replies?" + onlyOtherAccStr)
// now there are three scenarios:
// 1. we're asked for the whole collection and not a page -- we can just return the collection, with no items, but a link to 'first' page.
// 2. we're asked for a page but only_other_accounts has not been set in the query -- so we should just return the first page of the collection, with no items.
// 3. we're asked for a page, and only_other_accounts has been set, and min_id has optionally been set -- so we need to return some actual items!
switch {
case !page:
// scenario 1
// get the collection
collection, err := p.converter.StatusToASRepliesCollection(ctx, status, onlyOtherAccounts)
if err != nil { if err != nil {
err := gtserror.Newf("error parsing status uri %s: %w", status.URI, err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
data, err = ap.Serialize(collection) // Get *all* available replies for status (i.e. without paging).
if err != nil { replies, err := p.state.DB.GetStatusReplies(ctx, status.ID)
return nil, gtserror.NewErrorInternalError(err)
}
case page && !onlyOtherAccountsSet:
// scenario 2
// get the collection
collection, err := p.converter.StatusToASRepliesCollection(ctx, status, onlyOtherAccounts)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// but only return the first page
data, err = ap.Serialize(collection.GetActivityStreamsFirst().GetActivityStreamsCollectionPage())
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
default:
// scenario 3
// get immediate children
replies, err := p.state.DB.GetStatusChildren(ctx, status, true, minID)
if err != nil { if err != nil {
err := gtserror.Newf("error getting status replies: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// filter children and extract URIs if onlyOtherAccounts {
replyURIs := map[string]*url.URL{} // If 'onlyOtherAccounts' is set, drop all by original status author.
for _, r := range replies { replies = slices.DeleteFunc(replies, func(reply *gtsmodel.Status) bool {
// only show public or unlocked statuses as replies return reply.AccountID == status.AccountID
if r.Visibility != gtsmodel.VisibilityPublic && r.Visibility != gtsmodel.VisibilityUnlocked { })
continue
} }
// respect onlyOtherAccounts parameter // Reslice replies dropping all those invisible to requester.
if onlyOtherAccounts && r.AccountID == requestedAccount.ID { replies, err = p.filter.StatusesVisible(ctx, requester, replies)
continue
}
// only show replies that the status owner can see
visibleToStatusOwner, err := p.filter.StatusVisible(ctx, requestedAccount, r)
if err != nil || !visibleToStatusOwner {
continue
}
// only show replies that the requester can see
visibleToRequester, err := p.filter.StatusVisible(ctx, requestingAccount, r)
if err != nil || !visibleToRequester {
continue
}
rURI, err := url.Parse(r.URI)
if err != nil {
continue
}
replyURIs[r.ID] = rURI
}
repliesPage, err := p.converter.StatusURIsToASRepliesPage(ctx, status, onlyOtherAccounts, minID, replyURIs)
if err != nil { if err != nil {
err := gtserror.Newf("error filtering status replies: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
data, err = ap.Serialize(repliesPage)
if err != nil { var obj vocab.Type
return nil, gtserror.NewErrorInternalError(err)
// Start AS collection params.
var params ap.CollectionParams
params.ID = collectionID
params.Total = len(replies)
if page == nil {
// i.e. paging disabled, return collection
// that links to first page (i.e. path below).
params.Query = make(url.Values, 1)
params.Query.Set("limit", "20") // enables paging
obj = ap.NewASOrderedCollection(params)
} else {
// i.e. paging enabled
// Page and reslice the replies according to given parameters.
replies = paging.Page_PageFunc(page, replies, func(reply *gtsmodel.Status) string {
return reply.ID
})
// page ID values.
var lo, hi string
if len(replies) > 0 {
// Get the lowest and highest
// ID values, used for paging.
lo = replies[len(replies)-1].ID
hi = replies[0].ID
} }
// Start AS collection page params.
var pageParams ap.CollectionPageParams
pageParams.CollectionParams = params
// Current page details.
pageParams.Current = page
pageParams.Count = len(replies)
// Set linked next/prev parameters.
pageParams.Next = page.Next(lo, hi)
pageParams.Prev = page.Prev(lo, hi)
// Set the collection item property builder function.
pageParams.Append = func(i int, itemsProp ap.ItemsPropertyBuilder) {
// Get follower URI at index.
status := replies[i]
uri := status.URI
// Parse URL object from URI.
iri, err := url.Parse(uri)
if err != nil {
log.Errorf(ctx, "error parsing status uri %s: %v", uri, err)
return
}
// Add to item property.
itemsProp.AppendIRI(iri)
}
// Build AS collection page object from params.
obj = ap.NewASOrderedCollectionPage(pageParams)
}
// Serialized the prepared object.
data, err := ap.Serialize(obj)
if err != nil {
err := gtserror.Newf("error serializing: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} }
return data, nil return data, nil

View File

@ -156,23 +156,23 @@ func NewProcessor(
// //
// Start with sub processors that will // Start with sub processors that will
// be required by the workers processor. // be required by the workers processor.
commonProcessor := common.New(state, converter, federator, filter) common := common.New(state, converter, federator, filter)
processor.account = account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) processor.account = account.New(&common, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc)
processor.media = media.New(state, converter, mediaManager, federator.TransportController()) processor.media = media.New(state, converter, mediaManager, federator.TransportController())
processor.stream = stream.New(state, oauthServer) processor.stream = stream.New(state, oauthServer)
// Instantiate the rest of the sub // Instantiate the rest of the sub
// processors + pin them to this struct. // processors + pin them to this struct.
processor.account = account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) processor.account = account.New(&common, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc)
processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender) processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender)
processor.fedi = fedi.New(state, converter, federator, filter) processor.fedi = fedi.New(state, &common, converter, federator, filter)
processor.list = list.New(state, converter) processor.list = list.New(state, converter)
processor.markers = markers.New(state, converter) processor.markers = markers.New(state, converter)
processor.polls = polls.New(&commonProcessor, state, converter) processor.polls = polls.New(&common, state, converter)
processor.report = report.New(state, converter) processor.report = report.New(state, converter)
processor.timeline = timeline.New(state, converter, filter) processor.timeline = timeline.New(state, converter, filter)
processor.search = search.New(state, federator, converter, filter) processor.search = search.New(state, federator, converter, filter)
processor.status = status.New(state, &commonProcessor, &processor.polls, federator, converter, filter, parseMentionFunc) processor.status = status.New(state, &common, &processor.polls, federator, converter, filter, parseMentionFunc)
processor.user = user.New(state, emailSender) processor.user = user.New(state, emailSender)
// Workers processor handles asynchronous // Workers processor handles asynchronous

View File

@ -67,7 +67,7 @@ func (p *Processor) contextGet(
Descendants: []apimodel.Status{}, Descendants: []apimodel.Status{},
} }
parents, err := p.state.DB.GetStatusParents(ctx, targetStatus, false) parents, err := p.state.DB.GetStatusParents(ctx, targetStatus)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -85,7 +85,7 @@ func (p *Processor) contextGet(
return context.Ancestors[i].ID < context.Ancestors[j].ID return context.Ancestors[i].ID < context.Ancestors[j].ID
}) })
children, err := p.state.DB.GetStatusChildren(ctx, targetStatus, false, "") children, err := p.state.DB.GetStatusChildren(ctx, targetStatus.ID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View File

@ -33,3 +33,11 @@ func EqualPtrs[T comparable](t1, t2 *T) bool {
func Ptr[T any](t T) *T { func Ptr[T any](t T) *T {
return &t return &t
} }
// PtrValueOr returns either value of ptr, or default.
func PtrValueOr[T any](t *T, _default T) T {
if t != nil {
return *t
}
return _default
}

View File

@ -20,7 +20,6 @@ package visibility
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/cache"
@ -219,7 +218,7 @@ func (f *Filter) isVisibleConversation(ctx context.Context, owner *gtsmodel.Acco
status.AccountID, status.AccountID,
) )
if err != nil { if err != nil {
return false, fmt.Errorf("error checking follow %s->%s: %w", owner.ID, status.AccountID, err) return false, gtserror.Newf("error checking follow %s->%s: %w", owner.ID, status.AccountID, err)
} }
if !followAuthor { if !followAuthor {
@ -236,7 +235,7 @@ func (f *Filter) isVisibleConversation(ctx context.Context, owner *gtsmodel.Acco
mention.TargetAccountID, mention.TargetAccountID,
) )
if err != nil { if err != nil {
return false, fmt.Errorf("error checking mention follow %s->%s: %w", owner.ID, mention.TargetAccountID, err) return false, gtserror.Newf("error checking mention follow %s->%s: %w", owner.ID, mention.TargetAccountID, err)
} }
if follow { if follow {

View File

@ -19,11 +19,11 @@ package visibility
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
) )
@ -105,7 +105,7 @@ func (f *Filter) isStatusPublicTimelineable(ctx context.Context, requester *gtsm
parentID, parentID,
) )
if err != nil { if err != nil {
return false, fmt.Errorf("isStatusPublicTimelineable: error getting status parent %s: %w", parentID, err) return false, gtserror.Newf("error getting status parent %s: %w", parentID, err)
} }
if parent.AccountID != status.AccountID { if parent.AccountID != status.AccountID {

View File

@ -19,32 +19,26 @@ package visibility
import ( import (
"context" "context"
"fmt" "slices"
"github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
) )
// StatusesVisible calls StatusVisible for each status in the statuses slice, and returns a slice of only statuses which are visible to the requester. // StatusesVisible calls StatusVisible for each status in the statuses slice, and returns a slice of only statuses which are visible to the requester.
func (f *Filter) StatusesVisible(ctx context.Context, requester *gtsmodel.Account, statuses []*gtsmodel.Status) ([]*gtsmodel.Status, error) { func (f *Filter) StatusesVisible(ctx context.Context, requester *gtsmodel.Account, statuses []*gtsmodel.Status) ([]*gtsmodel.Status, error) {
// Preallocate slice of maximum possible length. var errs gtserror.MultiError
filtered := make([]*gtsmodel.Status, 0, len(statuses)) filtered := slices.DeleteFunc(statuses, func(status *gtsmodel.Status) bool {
for _, status := range statuses {
// Check whether status is visible to requester.
visible, err := f.StatusVisible(ctx, requester, status) visible, err := f.StatusVisible(ctx, requester, status)
if err != nil { if err != nil {
return nil, err errs.Append(err)
return true
} }
return !visible
if visible { })
// Add filtered status to ret slice. return filtered, errs.Combine()
filtered = append(filtered, status)
}
}
return filtered, nil
} }
// StatusVisible will check if given status is visible to requester, accounting for requester with no auth (i.e is nil), suspensions, disabled local users, account blocks and status privacy. // StatusVisible will check if given status is visible to requester, accounting for requester with no auth (i.e is nil), suspensions, disabled local users, account blocks and status privacy.
@ -85,13 +79,13 @@ func (f *Filter) StatusVisible(ctx context.Context, requester *gtsmodel.Account,
func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) { func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
// Ensure that status is fully populated for further processing. // Ensure that status is fully populated for further processing.
if err := f.state.DB.PopulateStatus(ctx, status); err != nil { if err := f.state.DB.PopulateStatus(ctx, status); err != nil {
return false, fmt.Errorf("isStatusVisible: error populating status %s: %w", status.ID, err) return false, gtserror.Newf("error populating status %s: %w", status.ID, err)
} }
// Check whether status accounts are visible to the requester. // Check whether status accounts are visible to the requester.
visible, err := f.areStatusAccountsVisible(ctx, requester, status) visible, err := f.areStatusAccountsVisible(ctx, requester, status)
if err != nil { if err != nil {
return false, fmt.Errorf("isStatusVisible: error checking status %s account visibility: %w", status.ID, err) return false, gtserror.Newf("error checking status %s account visibility: %w", status.ID, err)
} else if !visible { } else if !visible {
return false, nil return false, nil
} }
@ -127,7 +121,7 @@ func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Accoun
// Boosted status needs its mentions populating, fetch these from database. // Boosted status needs its mentions populating, fetch these from database.
status.BoostOf.Mentions, err = f.state.DB.GetMentions(ctx, status.BoostOf.MentionIDs) status.BoostOf.Mentions, err = f.state.DB.GetMentions(ctx, status.BoostOf.MentionIDs)
if err != nil { if err != nil {
return false, fmt.Errorf("isStatusVisible: error populating boosted status %s mentions: %w", status.BoostOfID, err) return false, gtserror.Newf("error populating boosted status %s mentions: %w", status.BoostOfID, err)
} }
} }
@ -145,7 +139,7 @@ func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Accoun
status.AccountID, status.AccountID,
) )
if err != nil { if err != nil {
return false, fmt.Errorf("isStatusVisible: error checking follow %s->%s: %w", requester.ID, status.AccountID, err) return false, gtserror.Newf("error checking follow %s->%s: %w", requester.ID, status.AccountID, err)
} }
if !follows { if !follows {
@ -162,7 +156,7 @@ func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Accoun
status.AccountID, status.AccountID,
) )
if err != nil { if err != nil {
return false, fmt.Errorf("isStatusVisible: error checking mutual follow %s<->%s: %w", requester.ID, status.AccountID, err) return false, gtserror.Newf("error checking mutual follow %s<->%s: %w", requester.ID, status.AccountID, err)
} }
if !mutuals { if !mutuals {
@ -187,7 +181,7 @@ func (f *Filter) areStatusAccountsVisible(ctx context.Context, requester *gtsmod
// Check whether status author's account is visible to requester. // Check whether status author's account is visible to requester.
visible, err := f.AccountVisible(ctx, requester, status.Account) visible, err := f.AccountVisible(ctx, requester, status.Account)
if err != nil { if err != nil {
return false, fmt.Errorf("error checking status author visibility: %w", err) return false, gtserror.Newf("error checking status author visibility: %w", err)
} }
if !visible { if !visible {
@ -206,7 +200,7 @@ func (f *Filter) areStatusAccountsVisible(ctx context.Context, requester *gtsmod
// Check whether boosted status author's account is visible to requester. // Check whether boosted status author's account is visible to requester.
visible, err := f.AccountVisible(ctx, requester, status.BoostOfAccount) visible, err := f.AccountVisible(ctx, requester, status.BoostOfAccount)
if err != nil { if err != nil {
return false, fmt.Errorf("error checking boosted author visibility: %w", err) return false, gtserror.Newf("error checking boosted author visibility: %w", err)
} }
if !visible { if !visible {

View File

@ -3163,7 +3163,7 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin
DateHeader: date, DateHeader: date,
} }
target = URLMustParse(statuses["local_account_1_status_1"].URI + "/replies") target = URLMustParse(statuses["local_account_1_status_1"].URI + "/replies?only_other_accounts=false")
sig, digest, date = GetSignatureForDereference(accounts["remote_account_1"].PublicKeyURI, accounts["remote_account_1"].PrivateKey, target) sig, digest, date = GetSignatureForDereference(accounts["remote_account_1"].PublicKeyURI, accounts["remote_account_1"].PrivateKey, target)
fossSatanDereferenceLocalAccount1Status1Replies := ActivityWithSignature{ fossSatanDereferenceLocalAccount1Status1Replies := ActivityWithSignature{
SignatureHeader: sig, SignatureHeader: sig,
@ -3179,7 +3179,7 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin
DateHeader: date, DateHeader: date,
} }
target = URLMustParse(statuses["local_account_1_status_1"].URI + "/replies?only_other_accounts=false&page=true&min_id=01FF25D5Q0DH7CHD57CTRS6WK0") target = URLMustParse(statuses["local_account_1_status_1"].URI + "/replies?min_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false")
sig, digest, date = GetSignatureForDereference(accounts["remote_account_1"].PublicKeyURI, accounts["remote_account_1"].PrivateKey, target) sig, digest, date = GetSignatureForDereference(accounts["remote_account_1"].PublicKeyURI, accounts["remote_account_1"].PrivateKey, target)
fossSatanDereferenceLocalAccount1Status1RepliesLast := ActivityWithSignature{ fossSatanDereferenceLocalAccount1Status1RepliesLast := ActivityWithSignature{
SignatureHeader: sig, SignatureHeader: sig,