From c9452f32f38b9ac1fb96a834202fd3e2f25897a1 Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Wed, 31 Jan 2024 13:31:53 +0000 Subject: [PATCH] [bugfix] fix possible infinite loops in media / emoji cleanup (#2590) * update media / emoji cleaner funcs to use new paging package, check for same returned maxID * fix other calls of getattachments and getmojis not using paging * use alternative order-by function --- cmd/gotosocial/action/admin/media/list.go | 66 ++++++++++---------- internal/cleaner/emoji.go | 63 +++++++++++++------ internal/cleaner/media.go | 37 +++++++---- internal/db/bundb/emoji.go | 15 +++-- internal/db/bundb/media.go | 11 +++- internal/db/emoji.go | 5 +- internal/db/media.go | 5 +- internal/util/slices.go | 75 +++++++++-------------- 8 files changed, 158 insertions(+), 119 deletions(-) diff --git a/cmd/gotosocial/action/admin/media/list.go b/cmd/gotosocial/action/admin/media/list.go index 0a2e60ede..ed10c967a 100644 --- a/cmd/gotosocial/action/admin/media/list.go +++ b/cmd/gotosocial/action/admin/media/list.go @@ -31,14 +31,14 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" ) type list struct { dbService db.DB state *state.State - maxID string - limit int + page paging.Page localOnly bool remoteOnly bool out *bufio.Writer @@ -47,31 +47,32 @@ type list struct { // Get a list of attachment using a custom filter func (l *list) GetAllAttachmentPaths(ctx context.Context, filter func(*gtsmodel.MediaAttachment) string) ([]string, error) { res := make([]string, 0, 100) + for { - attachments, err := l.dbService.GetAttachments(ctx, l.maxID, l.limit) - if err != nil { + // Get the next page of media attachments up to max ID. + attachments, err := l.dbService.GetAttachments(ctx, &l.page) + if err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, fmt.Errorf("failed to retrieve media metadata from database: %w", err) } + // Get current max ID. + maxID := l.page.Max.Value + + // If no attachments or the same group is returned, we reached the end. + if len(attachments) == 0 || maxID == attachments[len(attachments)-1].ID { + break + } + + // Use last ID as the next 'maxID' value. + maxID = attachments[len(attachments)-1].ID + l.page.Max = paging.MaxID(maxID) + for _, a := range attachments { v := filter(a) if v != "" { res = append(res, v) } } - - // If we got less results than our limit, we've reached the - // last page to retrieve and we can break the loop. If the - // last batch happens to contain exactly the same amount of - // items as the limit we'll end up doing one extra query. - if len(attachments) < l.limit { - break - } - - // Grab the last ID from the batch and set it as the maxID - // that'll be used in the next iteration so we don't get items - // we've already seen. - l.maxID = attachments[len(attachments)-1].ID } return res, nil } @@ -80,30 +81,30 @@ func (l *list) GetAllAttachmentPaths(ctx context.Context, filter func(*gtsmodel. func (l *list) GetAllEmojisPaths(ctx context.Context, filter func(*gtsmodel.Emoji) string) ([]string, error) { res := make([]string, 0, 100) for { - attachments, err := l.dbService.GetEmojis(ctx, l.maxID, l.limit) + // Get the next page of emoji media up to max ID. + attachments, err := l.dbService.GetEmojis(ctx, &l.page) if err != nil { return nil, fmt.Errorf("failed to retrieve media metadata from database: %w", err) } + // Get current max ID. + maxID := l.page.Max.Value + + // If no attachments or the same group is returned, we reached the end. + if len(attachments) == 0 || maxID == attachments[len(attachments)-1].ID { + break + } + + // Use last ID as the next 'maxID' value. + maxID = attachments[len(attachments)-1].ID + l.page.Max = paging.MaxID(maxID) + for _, a := range attachments { v := filter(a) if v != "" { res = append(res, v) } } - - // If we got less results than our limit, we've reached the - // last page to retrieve and we can break the loop. If the - // last batch happens to contain exactly the same amount of - // items as the limit we'll end up doing one extra query. - if len(attachments) < l.limit { - break - } - - // Grab the last ID from the batch and set it as the maxID - // that'll be used in the next iteration so we don't get items - // we've already seen. - l.maxID = attachments[len(attachments)-1].ID } return res, nil } @@ -137,8 +138,7 @@ func setupList(ctx context.Context) (*list, error) { return &list{ dbService: dbService, state: &state, - limit: 200, - maxID: "", + page: paging.Page{Limit: 200}, localOnly: localOnly, remoteOnly: remoteOnly, out: bufio.NewWriter(os.Stdout), diff --git a/internal/cleaner/emoji.go b/internal/cleaner/emoji.go index d2baec7e8..62ed0f012 100644 --- a/internal/cleaner/emoji.go +++ b/internal/cleaner/emoji.go @@ -27,6 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // Emoji encompasses a set of @@ -105,8 +106,9 @@ func (e *Emoji) UncacheRemote(ctx context.Context, olderThan time.Time) (int, er return total, gtserror.Newf("error getting remote emoji: %w", err) } - if len(emojis) == 0 { - // reached end. + // If no emojis / same group is returned, we reached the end. + if len(emojis) == 0 || + olderThan.Equal(emojis[len(emojis)-1].CreatedAt) { break } @@ -140,23 +142,30 @@ func (e *Emoji) UncacheRemote(ctx context.Context, olderThan time.Time) (int, er func (e *Emoji) FixBroken(ctx context.Context) (int, error) { var ( total int - maxID string + page paging.Page ) + // Set page select limit. + page.Limit = selectLimit + for { - // Fetch the next batch of emoji media up to next ID. - emojis, err := e.state.DB.GetEmojis(ctx, maxID, selectLimit) + // Fetch the next batch of emoji to next max ID. + emojis, err := e.state.DB.GetEmojis(ctx, &page) if err != nil && !errors.Is(err, db.ErrNoEntries) { return total, gtserror.Newf("error getting emojis: %w", err) } - if len(emojis) == 0 { - // reached end. + // Get current max ID. + maxID := page.Max.Value + + // If no emoji or the same group is returned, we reached end. + if len(emojis) == 0 || maxID == emojis[len(emojis)-1].ID { break } - // Use last as the next 'maxID' value. + // Use last ID as the next 'maxID'. maxID = emojis[len(emojis)-1].ID + page.Max = paging.MaxID(maxID) for _, emoji := range emojis { // Check / fix missing broken emoji. @@ -182,23 +191,30 @@ func (e *Emoji) FixBroken(ctx context.Context) (int, error) { func (e *Emoji) PruneUnused(ctx context.Context) (int, error) { var ( total int - maxID string + page paging.Page ) + // Set page select limit. + page.Limit = selectLimit + for { - // Fetch the next batch of emoji media up to next ID. - emojis, err := e.state.DB.GetRemoteEmojis(ctx, maxID, selectLimit) + // Fetch the next batch of emoji to next max ID. + emojis, err := e.state.DB.GetRemoteEmojis(ctx, &page) if err != nil && !errors.Is(err, db.ErrNoEntries) { return total, gtserror.Newf("error getting remote emojis: %w", err) } - if len(emojis) == 0 { - // reached end. + // Get current max ID. + maxID := page.Max.Value + + // If no emoji or the same group is returned, we reached end. + if len(emojis) == 0 || maxID == emojis[len(emojis)-1].ID { break } - // Use last as the next 'maxID' value. + // Use last ID as the next 'maxID'. maxID = emojis[len(emojis)-1].ID + page.Max = paging.MaxID(maxID) for _, emoji := range emojis { // Check / prune unused emoji media. @@ -224,23 +240,30 @@ func (e *Emoji) PruneUnused(ctx context.Context) (int, error) { func (e *Emoji) FixCacheStates(ctx context.Context) (int, error) { var ( total int - maxID string + page paging.Page ) + // Set page select limit. + page.Limit = selectLimit + for { - // Fetch the next batch of emoji media up to next ID. - emojis, err := e.state.DB.GetRemoteEmojis(ctx, maxID, selectLimit) + // Fetch the next batch of emoji to next max ID. + emojis, err := e.state.DB.GetRemoteEmojis(ctx, &page) if err != nil && !errors.Is(err, db.ErrNoEntries) { return total, gtserror.Newf("error getting remote emojis: %w", err) } - if len(emojis) == 0 { - // reached end. + // Get current max ID. + maxID := page.Max.Value + + // If no emoji or the same group is returned, we reached end. + if len(emojis) == 0 || maxID == emojis[len(emojis)-1].ID { break } - // Use last as the next 'maxID' value. + // Use last ID as the next 'maxID'. maxID = emojis[len(emojis)-1].ID + page.Max = paging.MaxID(maxID) for _, emoji := range emojis { // Check / fix required emoji cache states. diff --git a/internal/cleaner/media.go b/internal/cleaner/media.go index 6db205d13..f3cda5d87 100644 --- a/internal/cleaner/media.go +++ b/internal/cleaner/media.go @@ -28,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/regexes" "github.com/superseriousbusiness/gotosocial/internal/uris" ) @@ -128,23 +129,30 @@ func (m *Media) PruneOrphaned(ctx context.Context) (int, error) { func (m *Media) PruneUnused(ctx context.Context) (int, error) { var ( total int - maxID string + page paging.Page ) + // Set page select limit. + page.Limit = selectLimit + for { - // Fetch the next batch of media attachments up to next max ID. - attachments, err := m.state.DB.GetAttachments(ctx, maxID, selectLimit) + // Fetch the next batch of media attachments to next maxID. + attachments, err := m.state.DB.GetAttachments(ctx, &page) if err != nil && !errors.Is(err, db.ErrNoEntries) { return total, gtserror.Newf("error getting attachments: %w", err) } - if len(attachments) == 0 { - // reached end. + // Get current max ID. + maxID := page.Max.Value + + // If no attachments or the same group is returned, we reached the end. + if len(attachments) == 0 || maxID == attachments[len(attachments)-1].ID { break } // Use last ID as the next 'maxID' value. maxID = attachments[len(attachments)-1].ID + page.Max = paging.MaxID(maxID) for _, media := range attachments { // Check / prune unused media attachment. @@ -183,8 +191,9 @@ func (m *Media) UncacheRemote(ctx context.Context, olderThan time.Time) (int, er return total, gtserror.Newf("error getting remote attachments: %w", err) } - if len(attachments) == 0 { - // reached end. + // If no attachments / same group is returned, we reached the end. + if len(attachments) == 0 || + olderThan.Equal(attachments[len(attachments)-1].CreatedAt) { break } @@ -215,23 +224,29 @@ func (m *Media) UncacheRemote(ctx context.Context, olderThan time.Time) (int, er func (m *Media) FixCacheStates(ctx context.Context) (int, error) { var ( total int - maxID string + page paging.Page ) + // Set page select limit. + page.Limit = selectLimit + for { // Fetch the next batch of media attachments up to next max ID. - attachments, err := m.state.DB.GetRemoteAttachments(ctx, maxID, selectLimit) + attachments, err := m.state.DB.GetRemoteAttachments(ctx, &page) if err != nil && !errors.Is(err, db.ErrNoEntries) { return total, gtserror.Newf("error getting remote attachments: %w", err) } + // Get current max ID. + maxID := page.Max.Value - if len(attachments) == 0 { - // reached end. + // If no attachments or the same group is returned, we reached the end. + if len(attachments) == 0 || maxID == attachments[len(attachments)-1].ID { break } // Use last ID as the next 'maxID' value. maxID = attachments[len(attachments)-1].ID + page.Max = paging.MaxID(maxID) for _, media := range attachments { // Check / fix required media cache states. diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index d1cb9dfbd..608cb6417 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -30,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" @@ -326,8 +327,11 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable return e.GetEmojisByIDs(ctx, emojiIDs) } -func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) { - var emojiIDs []string +func (e *emojiDB) GetEmojis(ctx context.Context, page *paging.Page) ([]*gtsmodel.Emoji, error) { + maxID := page.GetMax() + limit := page.GetLimit() + + emojiIDs := make([]string, 0, limit) q := e.db.NewSelect(). Table("emojis"). @@ -349,8 +353,11 @@ func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gt return e.GetEmojisByIDs(ctx, emojiIDs) } -func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) { - var emojiIDs []string +func (e *emojiDB) GetRemoteEmojis(ctx context.Context, page *paging.Page) ([]*gtsmodel.Emoji, error) { + maxID := page.GetMax() + limit := page.GetLimit() + + emojiIDs := make([]string, 0, limit) q := e.db.NewSelect(). Table("emojis"). diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index ce3c90083..ced38a588 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -27,6 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" @@ -232,7 +233,10 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { return err } -func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) { +func (m *mediaDB) GetAttachments(ctx context.Context, page *paging.Page) ([]*gtsmodel.MediaAttachment, error) { + maxID := page.GetMax() + limit := page.GetLimit() + attachmentIDs := make([]string, 0, limit) q := m.db.NewSelect(). @@ -255,7 +259,10 @@ func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ( return m.GetAttachmentsByIDs(ctx, attachmentIDs) } -func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) { +func (m *mediaDB) GetRemoteAttachments(ctx context.Context, page *paging.Page) ([]*gtsmodel.MediaAttachment, error) { + maxID := page.GetMax() + limit := page.GetLimit() + attachmentIDs := make([]string, 0, limit) q := m.db.NewSelect(). diff --git a/internal/db/emoji.go b/internal/db/emoji.go index fed956933..26a881dbc 100644 --- a/internal/db/emoji.go +++ b/internal/db/emoji.go @@ -22,6 +22,7 @@ import ( "time" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // EmojiAllDomains can be used as the `domain` value in a GetEmojis @@ -47,10 +48,10 @@ type Emoji interface { GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, error) // GetEmojis fetches all emojis with IDs less than 'maxID', up to a maximum of 'limit' emojis. - GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) + GetEmojis(ctx context.Context, page *paging.Page) ([]*gtsmodel.Emoji, error) // GetRemoteEmojis fetches all remote emojis with IDs less than 'maxID', up to a maximum of 'limit' emojis. - GetRemoteEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) + GetRemoteEmojis(ctx context.Context, page *paging.Page) ([]*gtsmodel.Emoji, error) // GetCachedEmojisOlderThan fetches all cached remote emojis with 'updated_at' greater than 'olderThan', up to a maximum of 'limit' emojis. GetCachedEmojisOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.Emoji, error) diff --git a/internal/db/media.go b/internal/db/media.go index 0ef03226b..a41f8970a 100644 --- a/internal/db/media.go +++ b/internal/db/media.go @@ -22,6 +22,7 @@ import ( "time" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // Media contains functions related to creating/getting/removing media attachments. @@ -42,10 +43,10 @@ type Media interface { DeleteAttachment(ctx context.Context, id string) error // GetAttachments fetches media attachments up to a given max ID, and at most limit. - GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) + GetAttachments(ctx context.Context, page *paging.Page) ([]*gtsmodel.MediaAttachment, error) // GetRemoteAttachments fetches media attachments with a non-empty domain, up to a given max ID, and at most limit. - GetRemoteAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) + GetRemoteAttachments(ctx context.Context, page *paging.Page) ([]*gtsmodel.MediaAttachment, error) // GetCachedAttachmentsOlderThan gets limit n remote attachments (including avatars and headers) older than // the given time. These will be returned in order of attachment.created_at descending (i.e. newest to oldest). diff --git a/internal/util/slices.go b/internal/util/slices.go index 51d560dbd..0505229e5 100644 --- a/internal/util/slices.go +++ b/internal/util/slices.go @@ -17,6 +17,8 @@ package util +import "slices" + // Deduplicate deduplicates entries in the given slice. func Deduplicate[T comparable](in []T) []T { var ( @@ -47,6 +49,10 @@ func DeduplicateFunc[T any, C comparable](in []T, key func(v T) C) []T { deduped = make([]T, 0, inL) ) + if key == nil { + panic("nil func") + } + for _, v := range in { k := key(v) @@ -66,6 +72,10 @@ func DeduplicateFunc[T any, C comparable](in []T, key func(v T) C) []T { // passing each item to 'get' and deduplicating the end result. // Compared to Deduplicate() this returns []K, NOT input type []T. func Collate[T any, K comparable](in []T, get func(T) K) []K { + if get == nil { + panic("nil func") + } + ks := make([]K, 0, len(in)) km := make(map[K]struct{}, len(in)) @@ -86,50 +96,25 @@ func Collate[T any, K comparable](in []T, get func(T) K) []K { // OrderBy orders a slice of given type by the provided alternative slice of comparable type. func OrderBy[T any, K comparable](in []T, keys []K, key func(T) K) { - var ( - start int - offset int - ) - - for i := 0; i < len(keys); i++ { - var ( - // key at index. - k = keys[i] - - // sentinel - // idx value. - idx = -1 - ) - - // Look for model with key in slice. - for j := start; j < len(in); j++ { - if key(in[j]) == k { - idx = j - break - } - } - - if idx == -1 { - // model with key - // was not found. - offset++ - continue - } - - // Update - // start - start++ - - // Expected ID index. - exp := i - offset - - if idx == exp { - // Model is in expected - // location, keep going. - continue - } - - // Swap models at current and expected. - in[idx], in[exp] = in[exp], in[idx] + if key == nil { + panic("nil func") } + + // Create lookup of keys->idx. + m := make(map[K]int, len(in)) + for i, k := range keys { + m[k] = i + } + + // Sort according to the reverse lookup. + slices.SortFunc(in, func(a, b T) int { + ai := m[key(a)] + bi := m[key(b)] + if ai < bi { + return -1 + } else if bi < ai { + return +1 + } + return 0 + }) }