From 5f3e0957179eddd088e82b8f8f493164cbc9ce37 Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Tue, 25 Jul 2023 09:34:05 +0100 Subject: [PATCH] [performance] retry db queries on busy errors (#2025) * catch SQLITE_BUSY errors, wrap bun.DB to use our own busy retrier, remove unnecessary db.Error type Signed-off-by: kim * remove dead code Signed-off-by: kim * remove more dead code, add missing error arguments Signed-off-by: kim * update sqlite to use maxOpenConns() Signed-off-by: kim * add uncommitted changes Signed-off-by: kim * use direct calls-through for the ConnIface to make sure we don't double query hook Signed-off-by: kim * expose underlying bun.DB better Signed-off-by: kim * retry on the correct busy error Signed-off-by: kim * use longer possible maxRetries for db retry-backoff Signed-off-by: kim * remove the note regarding max-open-conns only applying to postgres Signed-off-by: kim * improved code commenting Signed-off-by: kim * remove unnecessary infof call (just use info) Signed-off-by: kim * rename DBConn to WrappedDB to better follow sql package name conventions Signed-off-by: kim * update test error string checks Signed-off-by: kim * shush linter Signed-off-by: kim * update backoff logic to be more transparent Signed-off-by: kim --------- Signed-off-by: kim --- example/config.yaml | 4 - internal/db/account.go | 46 ++-- internal/db/admin.go | 10 +- internal/db/basic.go | 26 +- internal/db/bundb/account.go | 120 ++++---- internal/db/bundb/account_test.go | 4 +- internal/db/bundb/admin.go | 36 +-- internal/db/bundb/basic.go | 70 ++--- internal/db/bundb/bundb.go | 92 +++---- internal/db/bundb/conn.go | 113 -------- internal/db/bundb/domain.go | 42 +-- internal/db/bundb/emoji.go | 90 +++--- internal/db/bundb/errors.go | 24 +- internal/db/bundb/instance.go | 56 ++-- internal/db/bundb/list.go | 40 +-- internal/db/bundb/media.go | 60 ++-- internal/db/bundb/mention.go | 18 +- internal/db/bundb/notification.go | 50 ++-- internal/db/bundb/relationship.go | 80 +++--- internal/db/bundb/relationship_block.go | 30 +- internal/db/bundb/relationship_follow.go | 28 +- internal/db/bundb/relationship_follow_req.go | 50 ++-- internal/db/bundb/report.go | 34 +-- internal/db/bundb/search.go | 20 +- internal/db/bundb/session.go | 15 +- internal/db/bundb/status.go | 88 +++--- internal/db/bundb/statusbookmark.go | 45 ++- internal/db/bundb/statusfave.go | 50 ++-- internal/db/bundb/timeline.go | 28 +- internal/db/bundb/tombstone.go | 22 +- internal/db/bundb/user.go | 56 ++-- internal/db/bundb/wrap.go | 258 ++++++++++++++++++ internal/db/domain.go | 16 +- internal/db/emoji.go | 26 +- internal/db/error.go | 24 +- internal/db/instance.go | 14 +- internal/db/media.go | 12 +- internal/db/mention.go | 4 +- internal/db/notification.go | 12 +- internal/db/relationship.go | 14 +- internal/db/report.go | 10 +- internal/db/session.go | 2 +- internal/db/status.go | 32 +-- internal/db/statusbookmark.go | 14 +- internal/db/statusfave.go | 14 +- internal/db/timeline.go | 6 +- internal/db/tombstone.go | 8 +- internal/db/user.go | 18 +- .../federation/dereferencing/account_test.go | 7 +- internal/httpclient/client.go | 2 +- internal/middleware/signaturecheck.go | 3 +- internal/processing/stream/authorize_test.go | 3 +- internal/web/web.go | 2 +- 53 files changed, 1050 insertions(+), 898 deletions(-) delete mode 100644 internal/db/bundb/conn.go create mode 100644 internal/db/bundb/wrap.go diff --git a/example/config.yaml b/example/config.yaml index 5f41952af..ce471ffed 100644 --- a/example/config.yaml +++ b/example/config.yaml @@ -194,10 +194,6 @@ db-tls-ca-cert: "" # # If you set the multiplier to less than 1, only one open connection will be used regardless of cpu count. # -# PLEASE NOTE!!: This setting currently only applies for Postgres. SQLite will always use 1 connection regardless -# of what is set here. This behavior will change in future when we implement better SQLITE_BUSY handling. -# See https://github.com/superseriousbusiness/gotosocial/issues/1407 for more details. -# # Examples: [16, 8, 10, 2] # Default: 8 db-max-open-conns-multiplier: 8 diff --git a/internal/db/account.go b/internal/db/account.go index 2e113c35e..21b8d6a1f 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -27,67 +27,67 @@ import ( // Account contains functions related to account getting/setting/creation. type Account interface { // GetAccountByID returns one account with the given ID, or an error if something goes wrong. - GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error) + GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, error) // GetAccountByURI returns one account with the given URI, or an error if something goes wrong. - GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, error) // GetAccountByURL returns one account with the given URL, or an error if something goes wrong. - GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) + GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, error) // GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong. - GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error) + GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) // GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong. - GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error) + GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, error) // GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong. - GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) // GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong. - GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) // GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong. - GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error) // GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong. - GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) // PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc). PopulateAccount(ctx context.Context, account *gtsmodel.Account) error // PutAccount puts one account in the database. - PutAccount(ctx context.Context, account *gtsmodel.Account) Error + PutAccount(ctx context.Context, account *gtsmodel.Account) error // UpdateAccount updates one account by ID. - UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) Error + UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) error // DeleteAccount deletes one account from the database by its ID. // DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the // account as suspended instead, rather than deleting from the db entirely. - DeleteAccount(ctx context.Context, id string) Error + DeleteAccount(ctx context.Context, id string) error // GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username. - GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error) + GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, error) // GetAccountFaves fetches faves/likes created by the target accountID. - GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error) + GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, error) // GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column. GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error) // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. - CountAccountStatuses(ctx context.Context, accountID string) (int, Error) + CountAccountStatuses(ctx context.Context, accountID string) (int, error) // CountAccountPinned returns the total number of pinned statuses owned by account with the given id. - CountAccountPinned(ctx context.Context, accountID string) (int, Error) + CountAccountPinned(ctx context.Context, accountID string) (int, error) // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // be very memory intensive so you probably shouldn't do this! // // In the case of no statuses, this function will return db.ErrNoEntries. - GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, Error) + GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) // GetAccountPinnedStatuses returns ONLY statuses owned by the give accountID for which a corresponding StatusPin // exists in the database. Statuses which are not pinned will not be returned by this function. @@ -95,28 +95,28 @@ type Account interface { // Statuses will be returned in the order in which they were pinned, from latest pinned to oldest pinned (descending). // // In the case of no statuses, this function will return db.ErrNoEntries. - GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, Error) + GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) // GetAccountWebStatuses is similar to GetAccountStatuses, but it's specifically for returning statuses that // should be visible via the web view of an account. So, only public, federated statuses that aren't boosts // or replies. // // In the case of no statuses, this function will return db.ErrNoEntries. - GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, Error) + GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) - GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) + GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) // GetAccountLastPosted simply gets the timestamp of the most recent post by the account. // // If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned. // // The returned time will be zero if account has never posted anything. - GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, Error) + GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. - SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error + SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error // GetInstanceAccount returns the instance account for the given domain. // If domain is empty, this instance account will be returned. - GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error) + GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error) } diff --git a/internal/db/admin.go b/internal/db/admin.go index 57ded68b1..717ac4b94 100644 --- a/internal/db/admin.go +++ b/internal/db/admin.go @@ -27,26 +27,26 @@ import ( type Admin interface { // IsUsernameAvailable checks whether a given username is available on our domain. // Returns an error if the username is already taken, or something went wrong in the db. - IsUsernameAvailable(ctx context.Context, username string) (bool, Error) + IsUsernameAvailable(ctx context.Context, username string) (bool, error) // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. // Return an error if: // A) the email is already associated with an account // B) we block signups from this email domain // C) something went wrong in the db - IsEmailAvailable(ctx context.Context, email string) (bool, Error) + IsEmailAvailable(ctx context.Context, email string) (bool, error) // NewSignup creates a new user in the database with the given parameters. // By the time this function is called, it should be assumed that all the parameters have passed validation! - NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, Error) + NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error) // CreateInstanceAccount creates an account in the database with the same username as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. // This is needed for things like serving files that belong to the instance and not an individual user/account. - CreateInstanceAccount(ctx context.Context) Error + CreateInstanceAccount(ctx context.Context) error // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'. // This is needed for things like serving instance information through /api/v1/instance - CreateInstanceInstance(ctx context.Context) Error + CreateInstanceInstance(ctx context.Context) error } diff --git a/internal/db/basic.go b/internal/db/basic.go index 3782f3621..f8c04c6b9 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -23,58 +23,58 @@ import "context" type Basic interface { // CreateTable creates a table for the given interface. // For implementations that don't use tables, this can just return nil. - CreateTable(ctx context.Context, i interface{}) Error + CreateTable(ctx context.Context, i interface{}) error // CreateAllTables creates *all* tables necessary for the running of GoToSocial. // Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized. - CreateAllTables(ctx context.Context) Error + CreateAllTables(ctx context.Context) error // DropTable drops the table for the given interface. // For implementations that don't use tables, this can just return nil. - DropTable(ctx context.Context, i interface{}) Error + DropTable(ctx context.Context, i interface{}) error // Stop should stop and close the database connection cleanly, returning an error if this is not possible. // If the database implementation doesn't need to be stopped, this can just return nil. - Stop(ctx context.Context) Error + Stop(ctx context.Context) error // IsHealthy should return nil if the database connection is healthy, or an error if not. - IsHealthy(ctx context.Context) Error + IsHealthy(ctx context.Context) error // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry, // for other implementations (for example, in-memory) it might just be the key of a map. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetByID(ctx context.Context, id string, i interface{}) Error + GetByID(ctx context.Context, id string, i interface{}) error // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the // name of the key to select from. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetWhere(ctx context.Context, where []Where, i interface{}) Error + GetWhere(ctx context.Context, where []Where, i interface{}) error // GetAll will try to get all entries of type i. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetAll(ctx context.Context, i interface{}) Error + GetAll(ctx context.Context, i interface{}) error // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Put(ctx context.Context, i interface{}) Error + Put(ctx context.Context, i interface{}) error // UpdateByID updates values of i based on its id. // If any columns are specified, these will be updated exclusively. // Otherwise, the whole model will be updated. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) Error + UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) error // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. - UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error + UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) error // DeleteByID removes i with id id. // If i didn't exist anyway, then no error should be returned. - DeleteByID(ctx context.Context, id string, i interface{}) Error + DeleteByID(ctx context.Context, id string, i interface{}) error // DeleteWhere deletes i where key = value // If i didn't exist anyway, then no error should be returned. - DeleteWhere(ctx context.Context, where []Where, i interface{}) Error + DeleteWhere(ctx context.Context, where []Where, i interface{}) error } diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 179db6bb3..2ef1618db 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -38,16 +38,16 @@ import ( ) type accountDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "ID", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.id"), id). Scan(ctx) @@ -77,12 +77,12 @@ func (a *accountDB) GetAccountsByIDs(ctx context.Context, ids []string) ([]*gtsm return accounts, nil } -func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "URI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.uri"), uri). Scan(ctx) @@ -91,12 +91,12 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. ) } -func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "URL", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.url"), url). Scan(ctx) @@ -105,7 +105,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. ) } -func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) { if domain != "" { // Normalize the domain as punycode var err error @@ -119,7 +119,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str ctx, "Username.Domain", func(account *gtsmodel.Account) error { - q := a.conn.NewSelect(). + q := a.db.NewSelect(). Model(account) if domain != "" { @@ -139,12 +139,12 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str ) } -func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "PublicKeyURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.public_key_uri"), id). Scan(ctx) @@ -153,12 +153,12 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo ) } -func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "InboxURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.inbox_uri"), uri). Scan(ctx) @@ -167,12 +167,12 @@ func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsm ) } -func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "OutboxURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.outbox_uri"), uri). Scan(ctx) @@ -181,12 +181,12 @@ func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gts ) } -func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "FollowersURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.followers_uri"), uri). Scan(ctx) @@ -195,12 +195,12 @@ func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (* ) } -func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "FollowingURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.following_uri"), uri). Scan(ctx) @@ -209,7 +209,7 @@ func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (* ) } -func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error) { var username string if domain == "" { @@ -223,14 +223,14 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts return a.GetAccountByUsernameDomain(ctx, username, domain) } -func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) { +func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) { // Fetch account from database cache with loader callback account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) { var account gtsmodel.Account // Not cached! Perform database query if err := dbQuery(&account); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } return &account, nil @@ -294,12 +294,12 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou return errs.Combine() } -func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { +func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error { return a.state.Caches.GTS.Account().Store(account, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + return a.db.RunInTx(ctx, func(tx bun.Tx) error { // create links between this account and any emojis it uses for _, i := range account.EmojiIDs { if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ @@ -317,7 +317,7 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) d }) } -func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) db.Error { +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) error { account.UpdatedAt = time.Now() if len(columns) > 0 { // If we're updating by column, ensure "updated_at" is included. @@ -328,7 +328,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + return a.db.RunInTx(ctx, func(tx bun.Tx) error { // create links between this account and any emojis it uses // first clear out any old emoji links if _, err := tx. @@ -362,7 +362,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account }) } -func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { +func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { defer a.state.Caches.GTS.Account().Invalidate("ID", id) // Load account into cache before attempting a delete, @@ -376,7 +376,7 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { return err } - return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + return a.db.RunInTx(ctx, func(tx bun.Tx) error { // clear out any emoji links if _, err := tx. NewDelete(). @@ -396,10 +396,10 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { }) } -func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) { +func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) { createdAt := time.Time{} - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Column("status.created_at"). @@ -416,12 +416,12 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, } if err := q.Scan(ctx, &createdAt); err != nil { - return time.Time{}, a.conn.ProcessError(err) + return time.Time{}, a.db.ProcessError(err) } return createdAt, nil } -func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { +func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error { if *mediaAttachment.Avatar && *mediaAttachment.Header { return errors.New("one media attachment cannot be both header and avatar") } @@ -437,26 +437,26 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen } // TODO: there are probably more side effects here that need to be handled - if _, err := a.conn. + if _, err := a.db. NewInsert(). Model(mediaAttachment). Exec(ctx); err != nil { - return a.conn.ProcessError(err) + return a.db.ProcessError(err) } - if _, err := a.conn. + if _, err := a.db. NewUpdate(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). Set("? = ?", column, mediaAttachment.ID). Where("? = ?", bun.Ident("account.id"), accountID). Exec(ctx); err != nil { - return a.conn.ProcessError(err) + return a.db.ProcessError(err) } return nil } -func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, db.Error) { +func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, error) { account, err := a.GetAccountByUsernameDomain(ctx, username, "") if err != nil { return "", err @@ -469,7 +469,7 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) ( var accountIDs []string // Create SELECT account query. - q := a.conn.NewSelect(). + q := a.db.NewSelect(). Table("accounts"). Column("id") @@ -486,37 +486,37 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) ( // Execute the query, scanning destination into accountIDs. if _, err := q.Exec(ctx, &accountIDs); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } // Convert account IDs into account objects. return a.GetAccountsByIDs(ctx, accountIDs) } -func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { +func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, error) { faves := new([]*gtsmodel.StatusFave) - if err := a.conn. + if err := a.db. NewSelect(). Model(faves). Where("? = ?", bun.Ident("status_fave.account_id"), accountID). Scan(ctx); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } return *faves, nil } -func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { - return a.conn. +func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) { + return a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Where("? = ?", bun.Ident("status.account_id"), accountID). Count(ctx) } -func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, db.Error) { - return a.conn. +func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) { + return a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Where("? = ?", bun.Ident("status.account_id"), accountID). @@ -524,7 +524,7 @@ func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (i Count(ctx) } -func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -536,7 +536,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li frontToBack = true ) - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). // Select only IDs from table @@ -562,7 +562,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li // implementation differs between SQLite and Postgres, // so we have to be thorough to cover all eventualities q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { - switch a.conn.Dialect().Name() { + switch a.db.Dialect().Name() { case dialect.PG: return q. Where("? IS NOT NULL", bun.Ident("status.attachments")). @@ -613,7 +613,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } // If we're paging up, we still want statuses @@ -628,10 +628,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li return a.statusesFromIDs(ctx, statusIDs) } -func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) { statusIDs := []string{} - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Column("status.id"). @@ -640,13 +640,13 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri Order("status.pinned_at DESC") if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } return a.statusesFromIDs(ctx, statusIDs) } -func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -655,7 +655,7 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, // Make educated guess for slice size statusIDs := make([]string, 0, limit) - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). // Select only IDs from table @@ -688,16 +688,16 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, q = q.Order("status.id DESC") if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } return a.statusesFromIDs(ctx, statusIDs) } -func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { +func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) { blocks := []*gtsmodel.Block{} - fq := a.conn. + fq := a.db. NewSelect(). Model(&blocks). Where("? = ?", bun.Ident("block.account_id"), accountID). @@ -717,7 +717,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI } if err := fq.Scan(ctx); err != nil { - return nil, "", "", a.conn.ProcessError(err) + return nil, "", "", a.db.ProcessError(err) } if len(blocks) == 0 { @@ -734,7 +734,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI return accounts, nextMaxID, prevMinID, nil } -func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) { // Catch case of no statuses early if len(statusIDs) == 0 { return nil, db.ErrNoEntries diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index bfe6df536..b410bb3ed 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -260,7 +260,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { } noCache := >smodel.Account{} - err = dbService.GetConn(). + err = dbService.DB(). NewSelect(). Model(noCache). Where("? = ?", bun.Ident("account.id"), testAccount.ID). @@ -288,7 +288,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { suite.Empty(updated.EmojiIDs) suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) - err = dbService.GetConn(). + err = dbService.DB(). NewSelect(). Model(noCache). Where("? = ?", bun.Ident("account.id"), testAccount.ID). diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 61ae1e044..fb1fb9d6c 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -44,21 +44,21 @@ import ( const rsaKeyBits = 2048 type adminDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { - q := a.conn. +func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, error) { + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). Column("account.id"). Where("? = ?", bun.Ident("account.username"), username). Where("? IS NULL", bun.Ident("account.domain")) - return a.conn.NotExists(ctx, q) + return a.db.NotExists(ctx, q) } -func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) { +func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, error) { // parse the domain from the email m, err := mail.ParseAddress(email) if err != nil { @@ -67,12 +67,12 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ // check if the email domain is blocked - emailDomainBlockedQ := a.conn. + emailDomainBlockedQ := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). Column("email_domain_block.id"). Where("? = ?", bun.Ident("email_domain_block.domain"), domain) - emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ) + emailDomainBlocked, err := a.db.Exists(ctx, emailDomainBlockedQ) if err != nil { return false, err } @@ -81,16 +81,16 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. } // check if this email is associated with a user already - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). Column("user.id"). Where("? = ?", bun.Ident("user.email"), email). WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) - return a.conn.NotExists(ctx, q) + return a.db.NotExists(ctx, q) } -func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, db.Error) { +func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error) { // If something went wrong previously while doing a new // sign up with this username, we might already have an // account, so check first. @@ -220,17 +220,17 @@ func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) ( return user, nil } -func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { +func (a *adminDB) CreateInstanceAccount(ctx context.Context) error { username := config.GetHost() - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). Column("account.id"). Where("? = ?", bun.Ident("account.username"), username). Where("? IS NULL", bun.Ident("account.domain")) - exists, err := a.conn.Exists(ctx, q) + exists, err := a.db.Exists(ctx, q) if err != nil { return err } @@ -277,18 +277,18 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { return nil } -func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { +func (a *adminDB) CreateInstanceInstance(ctx context.Context) error { protocol := config.GetProtocol() host := config.GetHost() // check if instance entry already exists - q := a.conn. + q := a.db. NewSelect(). Column("instance.id"). TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). Where("? = ?", bun.Ident("instance.domain"), host) - exists, err := a.conn.Exists(ctx, q) + exists, err := a.db.Exists(ctx, q) if err != nil { return err } @@ -309,13 +309,13 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { URI: fmt.Sprintf("%s://%s", protocol, host), } - insertQ := a.conn. + insertQ := a.db. NewInsert(). Model(i) _, err = insertQ.Exec(ctx) if err != nil { - return a.conn.ProcessError(err) + return a.db.ProcessError(err) } log.Infof(ctx, "created instance instance %s with id %s", host, i.ID) diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index 6406ede35..4991dcf69 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -28,99 +28,99 @@ import ( ) type basicDB struct { - conn *DBConn + db *WrappedDB } -func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error { - _, err := b.conn.NewInsert().Model(i).Exec(ctx) - return b.conn.ProcessError(err) +func (b *basicDB) Put(ctx context.Context, i interface{}) error { + _, err := b.db.NewInsert().Model(i).Exec(ctx) + return b.db.ProcessError(err) } -func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error { - q := b.conn. +func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) error { + q := b.db. NewSelect(). Model(i). Where("id = ?", id) err := q.Scan(ctx) - return b.conn.ProcessError(err) + return b.db.ProcessError(err) } -func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { +func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) error { if len(where) == 0 { return errors.New("no queries provided") } - q := b.conn.NewSelect().Model(i) + q := b.db.NewSelect().Model(i) selectWhere(q, where) err := q.Scan(ctx) - return b.conn.ProcessError(err) + return b.db.ProcessError(err) } -func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { - q := b.conn. +func (b *basicDB) GetAll(ctx context.Context, i interface{}) error { + q := b.db. NewSelect(). Model(i) err := q.Scan(ctx) - return b.conn.ProcessError(err) + return b.db.ProcessError(err) } -func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error { - q := b.conn. +func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) error { + q := b.db. NewDelete(). Model(i). Where("id = ?", id) _, err := q.Exec(ctx) - return b.conn.ProcessError(err) + return b.db.ProcessError(err) } -func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { +func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) error { if len(where) == 0 { return errors.New("no queries provided") } - q := b.conn. + q := b.db. NewDelete(). Model(i) deleteWhere(q, where) _, err := q.Exec(ctx) - return b.conn.ProcessError(err) + return b.db.ProcessError(err) } -func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) db.Error { - q := b.conn. +func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) error { + q := b.db. NewUpdate(). Model(i). Column(columns...). Where("? = ?", bun.Ident("id"), id) _, err := q.Exec(ctx) - return b.conn.ProcessError(err) + return b.db.ProcessError(err) } -func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { - q := b.conn.NewUpdate().Model(i) +func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) error { + q := b.db.NewUpdate().Model(i) updateWhere(q, where) q = q.Set("? = ?", bun.Ident(key), value) _, err := q.Exec(ctx) - return b.conn.ProcessError(err) + return b.db.ProcessError(err) } -func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { - _, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx) +func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error { + _, err := b.db.NewCreateTable().Model(i).IfNotExists().Exec(ctx) return err } -func (b *basicDB) CreateAllTables(ctx context.Context) db.Error { +func (b *basicDB) CreateAllTables(ctx context.Context) error { models := []interface{}{ >smodel.Account{}, >smodel.Application{}, @@ -154,16 +154,16 @@ func (b *basicDB) CreateAllTables(ctx context.Context) db.Error { return nil } -func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { - _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) - return b.conn.ProcessError(err) +func (b *basicDB) DropTable(ctx context.Context, i interface{}) error { + _, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx) + return b.db.ProcessError(err) } -func (b *basicDB) IsHealthy(ctx context.Context) db.Error { - return b.conn.PingContext(ctx) +func (b *basicDB) IsHealthy(ctx context.Context) error { + return b.db.DB.PingContext(ctx) } -func (b *basicDB) Stop(ctx context.Context) db.Error { +func (b *basicDB) Stop(ctx context.Context) error { log.Info(ctx, "closing db connection") - return b.conn.Close() + return b.db.DB.Close() } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index ee28800b5..5634f877f 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -79,13 +79,13 @@ type DBService struct { db.Timeline db.User db.Tombstone - conn *DBConn + db *WrappedDB } -// GetConn returns the underlying bun connection. +// GetDB returns the underlying database connection pool. // Should only be used in testing + exceptional circumstance. -func (dbService *DBService) GetConn() *DBConn { - return dbService.conn +func (dbService *DBService) DB() *WrappedDB { + return dbService.db } func doMigration(ctx context.Context, db *bun.DB) error { @@ -112,18 +112,18 @@ func doMigration(ctx context.Context, db *bun.DB) error { // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { - var conn *DBConn + var db *WrappedDB var err error t := strings.ToLower(config.GetDbType()) switch t { case "postgres": - conn, err = pgConn(ctx) + db, err = pgConn(ctx) if err != nil { return nil, err } case "sqlite": - conn, err = sqliteConn(ctx) + db, err = sqliteConn(ctx) if err != nil { return nil, err } @@ -132,15 +132,15 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { } // Add database query hooks. - conn.DB.AddQueryHook(queryHook{}) + db.AddQueryHook(queryHook{}) if config.GetTracingEnabled() { - conn.DB.AddQueryHook(tracing.InstrumentBun()) + db.AddQueryHook(tracing.InstrumentBun()) } // execute sqlite pragmas *after* adding database hook; // this allows the pragma queries to be logged if t == "sqlite" { - if err := sqlitePragmas(ctx, conn); err != nil { + if err := sqlitePragmas(ctx, db); err != nil { return nil, err } } @@ -148,103 +148,103 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { // table registration is needed for many-to-many, see: // https://bun.uptrace.dev/orm/many-to-many-relation/ for _, t := range registerTables { - conn.RegisterModel(t) + db.RegisterModel(t) } // perform any pending database migrations: this includes // the very first 'migration' on startup which just creates // necessary tables - if err := doMigration(ctx, conn.DB); err != nil { + if err := doMigration(ctx, db.DB); err != nil { return nil, fmt.Errorf("db migration error: %s", err) } ps := &DBService{ Account: &accountDB{ - conn: conn, + db: db, state: state, }, Admin: &adminDB{ - conn: conn, + db: db, state: state, }, Basic: &basicDB{ - conn: conn, + db: db, }, Domain: &domainDB{ - conn: conn, + db: db, state: state, }, Emoji: &emojiDB{ - conn: conn, + db: db, state: state, }, Instance: &instanceDB{ - conn: conn, + db: db, state: state, }, List: &listDB{ - conn: conn, + db: db, state: state, }, Media: &mediaDB{ - conn: conn, + db: db, state: state, }, Mention: &mentionDB{ - conn: conn, + db: db, state: state, }, Notification: ¬ificationDB{ - conn: conn, + db: db, state: state, }, Relationship: &relationshipDB{ - conn: conn, + db: db, state: state, }, Report: &reportDB{ - conn: conn, + db: db, state: state, }, Search: &searchDB{ - conn: conn, + db: db, state: state, }, Session: &sessionDB{ - conn: conn, + db: db, }, Status: &statusDB{ - conn: conn, + db: db, state: state, }, StatusBookmark: &statusBookmarkDB{ - conn: conn, + db: db, state: state, }, StatusFave: &statusFaveDB{ - conn: conn, + db: db, state: state, }, Timeline: &timelineDB{ - conn: conn, + db: db, state: state, }, User: &userDB{ - conn: conn, + db: db, state: state, }, Tombstone: &tombstoneDB{ - conn: conn, + db: db, state: state, }, - conn: conn, + db: db, } // we can confidently return this useable service now return ps, nil } -func pgConn(ctx context.Context) (*DBConn, error) { +func pgConn(ctx context.Context) (*WrappedDB, error) { opts, err := deriveBunDBPGOptions() //nolint:contextcheck if err != nil { return nil, fmt.Errorf("could not create bundb postgres options: %s", err) @@ -259,10 +259,10 @@ func pgConn(ctx context.Context) (*DBConn, error) { sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections - conn := WrapDBConn(bun.NewDB(sqldb, pgdialect.New())) + conn := WrapDB(bun.NewDB(sqldb, pgdialect.New())) // ping to check the db is there and listening - if err := conn.PingContext(ctx); err != nil { + if err := conn.DB.PingContext(ctx); err != nil { return nil, fmt.Errorf("postgres ping: %s", err) } @@ -270,7 +270,7 @@ func pgConn(ctx context.Context) (*DBConn, error) { return conn, nil } -func sqliteConn(ctx context.Context) (*DBConn, error) { +func sqliteConn(ctx context.Context) (*WrappedDB, error) { // validate db address has actually been set address := config.GetDbAddress() if address == "" { @@ -326,15 +326,15 @@ func sqliteConn(ctx context.Context) (*DBConn, error) { // Tune db connections for sqlite, see: // - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql // - https://www.alexedwards.net/blog/configuring-sqldb - sqldb.SetMaxOpenConns(1) // only 1 connection regardless of multiplier, see https://github.com/superseriousbusiness/gotosocial/issues/1407 - sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around - sqldb.SetConnMaxLifetime(0) // don't kill connections due to age + sqldb.SetMaxOpenConns(maxOpenConns()) // x number of conns per CPU + sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around + sqldb.SetConnMaxLifetime(0) // don't kill connections due to age // Wrap Bun database conn in our own wrapper - conn := WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New())) + conn := WrapDB(bun.NewDB(sqldb, sqlitedialect.New())) // ping to check the db is there and listening - if err := conn.PingContext(ctx); err != nil { + if err := conn.DB.PingContext(ctx); err != nil { if errWithCode, ok := err.(*sqlite.Error); ok { err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) } @@ -445,7 +445,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { // sqlitePragmas sets desired sqlite pragmas based on configured values, and // logs the results of the pragma queries. Errors if something goes wrong. -func sqlitePragmas(ctx context.Context, conn *DBConn) error { +func sqlitePragmas(ctx context.Context, db *WrappedDB) error { var pragmas [][]string if mode := config.GetDbSqliteJournalMode(); mode != "" { // Set the user provided SQLite journal mode @@ -475,12 +475,12 @@ func sqlitePragmas(ctx context.Context, conn *DBConn) error { pk := p[0] pv := p[1] - if _, err := conn.DB.ExecContext(ctx, "PRAGMA ?=?", bun.Ident(pk), bun.Safe(pv)); err != nil { + if _, err := db.ExecContext(ctx, "PRAGMA ?=?", bun.Ident(pk), bun.Safe(pv)); err != nil { return fmt.Errorf("error executing sqlite pragma %s: %w", pk, err) } var res string - if err := conn.DB.NewRaw("PRAGMA ?", bun.Ident(pk)).Scan(ctx, &res); err != nil { + if err := db.NewRaw("PRAGMA ?", bun.Ident(pk)).Scan(ctx, &res); err != nil { return fmt.Errorf("error scanning sqlite pragma %s: %w", pv, err) } @@ -502,7 +502,7 @@ func (dbService *DBService) TagStringToTag(ctx context.Context, t string, origin tag := >smodel.Tag{} // we can use selectorinsert here to create the new tag if it doesn't exist already // inserted will be true if this is a new tag we just created - if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil && err != sql.ErrNoRows { + if err := dbService.db.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("error getting tag with name %s: %s", t, err) } diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go deleted file mode 100644 index 4a9ec83de..000000000 --- a/internal/db/bundb/conn.go +++ /dev/null @@ -1,113 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package bundb - -import ( - "context" - "database/sql" - - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/uptrace/bun" - "github.com/uptrace/bun/dialect" -) - -// DBConn wrapps a bun.DB conn to provide SQL-type specific additional functionality -type DBConn struct { - errProc func(error) db.Error // errProc is the SQL-type specific error processor - *bun.DB // DB is the underlying bun.DB connection -} - -// WrapDBConn wraps a bun DB connection to provide our own error processing dependent on DB dialect. -func WrapDBConn(dbConn *bun.DB) *DBConn { - var errProc func(error) db.Error - switch dbConn.Dialect().Name() { - case dialect.PG: - errProc = processPostgresError - case dialect.SQLite: - errProc = processSQLiteError - default: - panic("unknown dialect name: " + dbConn.Dialect().Name().String()) - } - return &DBConn{ - errProc: errProc, - DB: dbConn, - } -} - -// RunInTx wraps execution of the supplied transaction function. -func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error { - return conn.ProcessError(func() error { - // Acquire a new transaction - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - return err - } - - var done bool - - defer func() { - if !done { - _ = tx.Rollback() - } - }() - - // Perform supplied transaction - if err := fn(tx); err != nil { - return err - } - - // Finally, commit - err = tx.Commit() //nolint:contextcheck - done = true - return err - }()) -} - -// ProcessError processes an error to replace any known values with our own db.Error types, -// making it easier to catch specific situations (e.g. no rows, already exists, etc) -func (conn *DBConn) ProcessError(err error) db.Error { - switch { - case err == nil: - return nil - case err == sql.ErrNoRows: - return db.ErrNoEntries - default: - return conn.errProc(err) - } -} - -// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors -func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { - exists, err := query.Exists(ctx) - - // Process error as our own and check if it exists - switch err := conn.ProcessError(err); err { - case nil: - return exists, nil - case db.ErrNoEntries: - return false, nil - default: - return false, err - } -} - -// NotExists is the functional opposite of conn.Exists() -func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { - exists, err := conn.Exists(ctx, query) - return !exists, err -} diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 2e8ce2a6b..07e1e9fca 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -30,11 +30,11 @@ import ( ) type domainDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { +func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error { // Normalize the domain as punycode var err error block.Domain, err = util.Punify(block.Domain) @@ -43,10 +43,10 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain } // Attempt to store domain block in DB - if _, err := d.conn.NewInsert(). + if _, err := d.db.NewInsert(). Model(block). Exec(ctx); err != nil { - return d.conn.ProcessError(err) + return d.db.ProcessError(err) } // Clear the domain block cache (for later reload) @@ -55,7 +55,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain return nil } -func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { +func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) { // Normalize the domain as punycode domain, err := util.Punify(domain) if err != nil { @@ -71,12 +71,12 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel var block gtsmodel.DomainBlock // Look for block matching domain in DB - q := d.conn. + q := d.db. NewSelect(). Model(&block). Where("? = ?", bun.Ident("domain_block.domain"), domain) if err := q.Scan(ctx); err != nil { - return nil, d.conn.ProcessError(err) + return nil, d.db.ProcessError(err) } return &block, nil @@ -85,31 +85,31 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) { blocks := []*gtsmodel.DomainBlock{} - if err := d.conn. + if err := d.db. NewSelect(). Model(&blocks). Scan(ctx); err != nil { - return nil, d.conn.ProcessError(err) + return nil, d.db.ProcessError(err) } return blocks, nil } -func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, db.Error) { +func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, error) { var block gtsmodel.DomainBlock - q := d.conn. + q := d.db. NewSelect(). Model(&block). Where("? = ?", bun.Ident("domain_block.id"), id) if err := q.Scan(ctx); err != nil { - return nil, d.conn.ProcessError(err) + return nil, d.db.ProcessError(err) } return &block, nil } -func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { +func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error { // Normalize the domain as punycode domain, err := util.Punify(domain) if err != nil { @@ -117,11 +117,11 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro } // Attempt to delete domain block - if _, err := d.conn.NewDelete(). + if _, err := d.db.NewDelete(). Model((*gtsmodel.DomainBlock)(nil)). Where("? = ?", bun.Ident("domain_block.domain"), domain). Exec(ctx); err != nil { - return d.conn.ProcessError(err) + return d.db.ProcessError(err) } // Clear the domain block cache (for later reload) @@ -130,7 +130,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro return nil } -func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { +func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, error) { // Normalize the domain as punycode domain, err := util.Punify(domain) if err != nil { @@ -148,18 +148,18 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db var domains []string // Scan list of all blocked domains from DB - q := d.conn.NewSelect(). + q := d.db.NewSelect(). Table("domain_blocks"). Column("domain") if err := q.Scan(ctx, &domains); err != nil { - return nil, d.conn.ProcessError(err) + return nil, d.db.ProcessError(err) } return domains, nil }) } -func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { +func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) { for _, domain := range domains { if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { return false, err @@ -170,11 +170,11 @@ func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (boo return false, nil } -func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) { +func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, error) { return d.IsDomainBlocked(ctx, uri.Hostname()) } -func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) { +func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, error) { for _, uri := range uris { if blocked, err := d.IsDomainBlocked(ctx, uri.Hostname()); err != nil { return false, err diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 321b5c0e7..90bcd134d 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -34,14 +34,14 @@ import ( ) type emojiDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error { +func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error { return e.state.Caches.GTS.Emoji().Store(emoji, func() error { - _, err := e.conn.NewInsert().Model(emoji).Exec(ctx) - return e.conn.ProcessError(err) + _, err := e.db.NewInsert().Model(emoji).Exec(ctx) + return e.db.ProcessError(err) }) } @@ -54,17 +54,17 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column // Update the emoji model in the database. return e.state.Caches.GTS.Emoji().Store(emoji, func() error { - _, err := e.conn. + _, err := e.db. NewUpdate(). Model(emoji). Where("? = ?", bun.Ident("emoji.id"), emoji.ID). Column(columns...). Exec(ctx) - return e.conn.ProcessError(err) + return e.db.ProcessError(err) }) } -func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error { +func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error { var ( accountIDs []string statusIDs []string @@ -105,7 +105,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error { return err } - return e.conn.RunInTx(ctx, func(tx bun.Tx) error { + return e.db.RunInTx(ctx, func(tx bun.Tx) error { // delete links between this emoji and any statuses that use it // TODO: remove when we delete this table if _, err := tx. @@ -229,7 +229,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error { func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, error) { emojiIDs := []string{} - subQuery := e.conn. + subQuery := e.db. NewSelect(). ColumnExpr("? AS ?", bun.Ident("emoji.id"), bun.Ident("emoji_ids")) @@ -255,7 +255,7 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable // "emojis" AS "emoji" // ORDER BY // "shortcode_domain" ASC - switch e.conn.Dialect().Name() { + switch e.db.Dialect().Name() { case dialect.SQLite: subQuery = subQuery.ColumnExpr("LOWER(? || ? || COALESCE(?, ?)) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain")) case dialect.PG: @@ -321,12 +321,12 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable // ORDER BY // "shortcode_domain" ASC // ) AS "subquery" - if err := e.conn. + if err := e.db. NewSelect(). Column("subquery.emoji_ids"). TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")). Scan(ctx, &emojiIDs); err != nil { - return nil, e.conn.ProcessError(err) + return nil, e.db.ProcessError(err) } if order == "DESC" { @@ -346,7 +346,7 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) { var emojiIDs []string - q := e.conn.NewSelect(). + q := e.db.NewSelect(). Table("emojis"). Column("id"). Order("id DESC") @@ -360,7 +360,7 @@ func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gt } if err := q.Scan(ctx, &emojiIDs); err != nil { - return nil, e.conn.ProcessError(err) + return nil, e.db.ProcessError(err) } return e.GetEmojisByIDs(ctx, emojiIDs) @@ -369,7 +369,7 @@ func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gt func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) { var emojiIDs []string - q := e.conn.NewSelect(). + q := e.db.NewSelect(). Table("emojis"). Column("id"). Where("domain IS NOT NULL"). @@ -384,7 +384,7 @@ func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int) } if err := q.Scan(ctx, &emojiIDs); err != nil { - return nil, e.conn.ProcessError(err) + return nil, e.db.ProcessError(err) } return e.GetEmojisByIDs(ctx, emojiIDs) @@ -393,7 +393,7 @@ func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int) func (e *emojiDB) GetCachedEmojisOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.Emoji, error) { var emojiIDs []string - q := e.conn.NewSelect(). + q := e.db.NewSelect(). Table("emojis"). Column("id"). Where("cached = true"). @@ -406,16 +406,16 @@ func (e *emojiDB) GetCachedEmojisOlderThan(ctx context.Context, olderThan time.T } if err := q.Scan(ctx, &emojiIDs); err != nil { - return nil, e.conn.ProcessError(err) + return nil, e.db.ProcessError(err) } return e.GetEmojisByIDs(ctx, emojiIDs) } -func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, error) { emojiIDs := []string{} - q := e.conn. + q := e.db. NewSelect(). TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")). Column("emoji.id"). @@ -425,18 +425,18 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.E Order("emoji.shortcode ASC") if err := q.Scan(ctx, &emojiIDs); err != nil { - return nil, e.conn.ProcessError(err) + return nil, e.db.ProcessError(err) } return e.GetEmojisByIDs(ctx, emojiIDs) } -func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, error) { return e.getEmoji( ctx, "ID", func(emoji *gtsmodel.Emoji) error { - return e.conn. + return e.db. NewSelect(). Model(emoji). Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx) @@ -445,12 +445,12 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, ) } -func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, error) { return e.getEmoji( ctx, "URI", func(emoji *gtsmodel.Emoji) error { - return e.conn. + return e.db. NewSelect(). Model(emoji). Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx) @@ -459,12 +459,12 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj ) } -func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) { return e.getEmoji( ctx, "Shortcode.Domain", func(emoji *gtsmodel.Emoji) error { - q := e.conn. + q := e.db. NewSelect(). Model(emoji) @@ -483,12 +483,12 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin ) } -func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, error) { return e.getEmoji( ctx, "ImageStaticURL", func(emoji *gtsmodel.Emoji) error { - return e.conn. + return e.db. NewSelect(). Model(emoji). Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL). @@ -498,35 +498,35 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string ) } -func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error { +func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error { return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error { - _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx) - return e.conn.ProcessError(err) + _, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx) + return e.db.ProcessError(err) }) } -func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) { +func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, error) { emojiCategoryIDs := []string{} - q := e.conn. + q := e.db. NewSelect(). TableExpr("? AS ?", bun.Ident("emoji_categories"), bun.Ident("emoji_category")). Column("emoji_category.id"). Order("emoji_category.name ASC") if err := q.Scan(ctx, &emojiCategoryIDs); err != nil { - return nil, e.conn.ProcessError(err) + return nil, e.db.ProcessError(err) } return e.GetEmojiCategoriesByIDs(ctx, emojiCategoryIDs) } -func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) { +func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, error) { return e.getEmojiCategory( ctx, "ID", func(emojiCategory *gtsmodel.EmojiCategory) error { - return e.conn. + return e.db. NewSelect(). Model(emojiCategory). Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx) @@ -535,12 +535,12 @@ func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.Em ) } -func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) { +func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) { return e.getEmojiCategory( ctx, "Name", func(emojiCategory *gtsmodel.EmojiCategory) error { - return e.conn. + return e.db. NewSelect(). Model(emojiCategory). Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx) @@ -549,14 +549,14 @@ func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gts ) } -func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) { // Fetch emoji from database cache with loader callback emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) { var emoji gtsmodel.Emoji // Not cached! Perform database query if err := dbQuery(&emoji); err != nil { - return nil, e.conn.ProcessError(err) + return nil, e.db.ProcessError(err) } return &emoji, nil @@ -580,7 +580,7 @@ func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gts return emoji, nil } -func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) { if len(emojiIDs) == 0 { return nil, db.ErrNoEntries } @@ -600,20 +600,20 @@ func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gts return emojis, nil } -func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) { +func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) { return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) { var category gtsmodel.EmojiCategory // Not cached! Perform database query if err := dbQuery(&category); err != nil { - return nil, e.conn.ProcessError(err) + return nil, e.db.ProcessError(err) } return &category, nil }, keyParts...) } -func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) { +func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, error) { if len(emojiCategoryIDs) == 0 { return nil, db.ErrNoEntries } diff --git a/internal/db/bundb/errors.go b/internal/db/bundb/errors.go index 6236b82d8..6bec8edae 100644 --- a/internal/db/bundb/errors.go +++ b/internal/db/bundb/errors.go @@ -18,14 +18,20 @@ package bundb import ( + "errors" + "github.com/jackc/pgconn" "github.com/superseriousbusiness/gotosocial/internal/db" "modernc.org/sqlite" sqlite3 "modernc.org/sqlite/lib" ) +// errBusy is a sentinel error indicating +// busy database (e.g. retry needed). +var errBusy = errors.New("busy") + // processPostgresError processes an error, replacing any postgres specific errors with our own error type -func processPostgresError(err error) db.Error { +func processPostgresError(err error) error { // Attempt to cast as postgres pgErr, ok := err.(*pgconn.PgError) if !ok { @@ -34,16 +40,16 @@ func processPostgresError(err error) db.Error { // Handle supplied error code: // (https://www.postgresql.org/docs/10/errcodes-appendix.html) - switch pgErr.Code { + switch pgErr.Code { //nolint case "23505" /* unique_violation */ : return db.ErrAlreadyExists - default: - return err } + + return err } // processSQLiteError processes an error, replacing any sqlite specific errors with our own error type -func processSQLiteError(err error) db.Error { +func processSQLiteError(err error) error { // Attempt to cast as sqlite sqliteErr, ok := err.(*sqlite.Error) if !ok { @@ -55,7 +61,11 @@ func processSQLiteError(err error) db.Error { case sqlite3.SQLITE_CONSTRAINT_UNIQUE, sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY: return db.ErrAlreadyExists - default: - return err + case sqlite3.SQLITE_BUSY: + return errBusy + case sqlite3.SQLITE_BUSY_TIMEOUT: + return db.ErrBusyTimeout } + + return err } diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 60d77600a..48332c731 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -34,12 +34,12 @@ import ( ) type instanceDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { - q := i.conn. +func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, error) { + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). Column("account.id"). @@ -56,13 +56,13 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int count, err := q.Count(ctx) if err != nil { - return 0, i.conn.ProcessError(err) + return 0, i.db.ProcessError(err) } return count, nil } -func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { - q := i.conn. +func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, error) { + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")) @@ -78,13 +78,13 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) ( count, err := q.Count(ctx) if err != nil { - return 0, i.conn.ProcessError(err) + return 0, i.db.ProcessError(err) } return count, nil } -func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { - q := i.conn. +func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, error) { + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")) @@ -101,12 +101,12 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i count, err := q.Count(ctx) if err != nil { - return 0, i.conn.ProcessError(err) + return 0, i.db.ProcessError(err) } return count, nil } -func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) { +func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, error) { // Normalize the domain as punycode var err error domain, err = util.Punify(domain) @@ -118,7 +118,7 @@ func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel. ctx, "Domain", func(instance *gtsmodel.Instance) error { - return i.conn.NewSelect(). + return i.db.NewSelect(). Model(instance). Where("? = ?", bun.Ident("instance.domain"), domain). Scan(ctx) @@ -132,7 +132,7 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel. ctx, "ID", func(instance *gtsmodel.Instance) error { - return i.conn.NewSelect(). + return i.db.NewSelect(). Model(instance). Where("? = ?", bun.Ident("instance.id"), id). Scan(ctx) @@ -141,14 +141,14 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel. ) } -func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, db.Error) { +func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) { // Fetch instance from database cache with loader callback instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) { var instance gtsmodel.Instance // Not cached! Perform database query. if err := dbQuery(&instance); err != nil { - return nil, i.conn.ProcessError(err) + return nil, i.db.ProcessError(err) } return &instance, nil @@ -210,8 +210,8 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc } return i.state.Caches.GTS.Instance().Store(instance, func() error { - _, err := i.conn.NewInsert().Model(instance).Exec(ctx) - return i.conn.ProcessError(err) + _, err := i.db.NewInsert().Model(instance).Exec(ctx) + return i.db.ProcessError(err) }) } @@ -230,20 +230,20 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst } return i.state.Caches.GTS.Instance().Store(instance, func() error { - _, err := i.conn. + _, err := i.db. NewUpdate(). Model(instance). Where("? = ?", bun.Ident("instance.id"), instance.ID). Column(columns...). Exec(ctx) - return i.conn.ProcessError(err) + return i.db.ProcessError(err) }) } -func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) { +func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, error) { instanceIDs := []string{} - q := i.conn. + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). // Select just the IDs of each instance. @@ -256,7 +256,7 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool } if err := q.Scan(ctx, &instanceIDs); err != nil { - return nil, i.conn.ProcessError(err) + return nil, i.db.ProcessError(err) } if len(instanceIDs) == 0 { @@ -280,7 +280,7 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool return instances, nil } -func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { +func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -296,7 +296,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max // Make educated guess for slice size accountIDs := make([]string, 0, limit) - q := i.conn. + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). // Select just the account ID. @@ -315,7 +315,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max } if err := q.Scan(ctx, &accountIDs); err != nil { - return nil, i.conn.ProcessError(err) + return nil, i.db.ProcessError(err) } // Catch case of no accounts early. @@ -340,13 +340,13 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max return accounts, nil } -func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]string, db.Error) { +func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]string, error) { addresses := []string{} // Select email addresses of approved, confirmed, // and enabled moderators or admins. - q := i.conn. + q := i.db. NewSelect(). TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). Column("user.email"). @@ -361,7 +361,7 @@ func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]strin OrderExpr("? ASC", bun.Ident("user.email")) if err := q.Scan(ctx, &addresses); err != nil { - return nil, i.conn.ProcessError(err) + return nil, i.db.ProcessError(err) } if len(addresses) == 0 { diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 837dfac27..25bb3a65d 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -33,7 +33,7 @@ import ( ) type listDB struct { - conn *DBConn + db *WrappedDB state *state.State } @@ -46,7 +46,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er ctx, "ID", func(list *gtsmodel.List) error { - return l.conn.NewSelect(). + return l.db.NewSelect(). Model(list). Where("? = ?", bun.Ident("list.id"), id). Scan(ctx) @@ -61,7 +61,7 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo // Not cached! Perform database query. if err := dbQuery(&list); err != nil { - return nil, l.conn.ProcessError(err) + return nil, l.db.ProcessError(err) } return &list, nil @@ -86,14 +86,14 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) { // Fetch IDs of all lists owned by this account. var listIDs []string - if err := l.conn. + if err := l.db. NewSelect(). TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")). Column("list.id"). Where("? = ?", bun.Ident("list.account_id"), accountID). Order("list.id DESC"). Scan(ctx, &listIDs); err != nil { - return nil, l.conn.ProcessError(err) + return nil, l.db.ProcessError(err) } if len(listIDs) == 0 { @@ -148,8 +148,8 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { return l.state.Caches.GTS.List().Store(list, func() error { - _, err := l.conn.NewInsert().Model(list).Exec(ctx) - return l.conn.ProcessError(err) + _, err := l.db.NewInsert().Model(list).Exec(ctx) + return l.db.ProcessError(err) }) } @@ -171,12 +171,12 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. }() return l.state.Caches.GTS.List().Store(list, func() error { - _, err := l.conn.NewUpdate(). + _, err := l.db.NewUpdate(). Model(list). Where("? = ?", bun.Ident("list.id"), list.ID). Column(columns...). Exec(ctx) - return l.conn.ProcessError(err) + return l.db.ProcessError(err) }) } @@ -207,7 +207,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error { } }() - return l.conn.RunInTx(ctx, func(tx bun.Tx) error { + return l.db.RunInTx(ctx, func(tx bun.Tx) error { // Delete all entries attached to list. if _, err := tx.NewDelete(). Table("list_entries"). @@ -234,7 +234,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis ctx, "ID", func(listEntry *gtsmodel.ListEntry) error { - return l.conn.NewSelect(). + return l.db.NewSelect(). Model(listEntry). Where("? = ?", bun.Ident("list_entry.id"), id). Scan(ctx) @@ -249,7 +249,7 @@ func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(* // Not cached! Perform database query. if err := dbQuery(&listEntry); err != nil { - return nil, l.conn.ProcessError(err) + return nil, l.db.ProcessError(err) } return &listEntry, nil @@ -289,7 +289,7 @@ func (l *listDB) GetListEntries(ctx context.Context, frontToBack = true ) - q := l.conn. + q := l.db. NewSelect(). TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). // Select only IDs from table @@ -329,7 +329,7 @@ func (l *listDB) GetListEntries(ctx context.Context, } if err := q.Scan(ctx, &entryIDs); err != nil { - return nil, l.conn.ProcessError(err) + return nil, l.db.ProcessError(err) } if len(entryIDs) == 0 { @@ -362,7 +362,7 @@ func (l *listDB) GetListEntries(ctx context.Context, func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { var entryIDs []string - if err := l.conn. + if err := l.db. NewSelect(). TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). // Select only IDs from table @@ -370,7 +370,7 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) // Select only entries belonging with given followID. Where("? = ?", bun.Ident("entry.follow_id"), followID). Scan(ctx, &entryIDs); err != nil { - return nil, l.conn.ProcessError(err) + return nil, l.db.ProcessError(err) } if len(entryIDs) == 0 { @@ -424,7 +424,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt }() // Finally, insert each list entry into the database. - return l.conn.RunInTx(ctx, func(tx bun.Tx) error { + return l.db.RunInTx(ctx, func(tx bun.Tx) error { for _, entry := range entries { if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error { _, err := tx. @@ -468,7 +468,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { }() // Finally delete the list entry. - _, err = l.conn.NewDelete(). + _, err = l.db.NewDelete(). Table("list_entries"). Where("? = ?", bun.Ident("id"), id). Exec(ctx) @@ -479,14 +479,14 @@ func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID stri var entryIDs []string // Fetch entry IDs for follow ID. - if err := l.conn. + if err := l.db. NewSelect(). Table("list_entries"). Column("id"). Where("? = ?", bun.Ident("follow_id"), followID). Order("id DESC"). Scan(ctx, &entryIDs); err != nil { - return l.conn.ProcessError(err) + return l.db.ProcessError(err) } for _, id := range entryIDs { diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index c190df44a..3b885af61 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -32,16 +32,16 @@ import ( ) type mediaDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) { +func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, error) { return m.getAttachment( ctx, "ID", func(attachment *gtsmodel.MediaAttachment) error { - return m.conn.NewSelect(). + return m.db.NewSelect(). Model(attachment). Where("? = ?", bun.Ident("media_attachment.id"), id). Scan(ctx) @@ -68,13 +68,13 @@ func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gts return attachments, nil } -func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) { +func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) { return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) { var attachment gtsmodel.MediaAttachment // Not cached! Perform database query if err := dbQuery(&attachment); err != nil { - return nil, m.conn.ProcessError(err) + return nil, m.db.ProcessError(err) } return &attachment, nil @@ -83,8 +83,8 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error { return m.state.Caches.GTS.Media().Store(media, func() error { - _, err := m.conn.NewInsert().Model(media).Exec(ctx) - return m.conn.ProcessError(err) + _, err := m.db.NewInsert().Model(media).Exec(ctx) + return m.db.ProcessError(err) }) } @@ -96,12 +96,12 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt } return m.state.Caches.GTS.Media().Store(media, func() error { - _, err := m.conn.NewUpdate(). + _, err := m.db.NewUpdate(). Model(media). Where("? = ?", bun.Ident("media_attachment.id"), media.ID). Column(columns...). Exec(ctx) - return m.conn.ProcessError(err) + return m.db.ProcessError(err) }) } @@ -126,7 +126,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { ) // Delete media attachment in new transaction. - err = m.conn.RunInTx(ctx, func(tx bun.Tx) error { + err = m.db.RunInTx(ctx, func(tx bun.Tx) error { if media.AccountID != "" { var account gtsmodel.Account @@ -229,11 +229,11 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID) } - return m.conn.ProcessError(err) + return m.db.ProcessError(err) } -func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) { - q := m.conn. +func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, error) { + q := m.db. NewSelect(). TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). Column("media_attachment.id"). @@ -243,7 +243,7 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) count, err := q.Count(ctx) if err != nil { - return 0, m.conn.ProcessError(err) + return 0, m.db.ProcessError(err) } return count, nil @@ -252,7 +252,7 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) { attachmentIDs := make([]string, 0, limit) - q := m.conn.NewSelect(). + q := m.db.NewSelect(). Table("media_attachments"). Column("id"). Order("id DESC") @@ -266,7 +266,7 @@ func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ( } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.conn.ProcessError(err) + return nil, m.db.ProcessError(err) } return m.GetAttachmentsByIDs(ctx, attachmentIDs) @@ -275,7 +275,7 @@ func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ( func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) { attachmentIDs := make([]string, 0, limit) - q := m.conn.NewSelect(). + q := m.db.NewSelect(). Table("media_attachments"). Column("id"). Where("remote_url IS NOT NULL"). @@ -290,16 +290,16 @@ func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.conn.ProcessError(err) + return nil, m.db.ProcessError(err) } return m.GetAttachmentsByIDs(ctx, attachmentIDs) } -func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) { +func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error) { attachmentIDs := make([]string, 0, limit) - q := m.conn. + q := m.db. NewSelect(). Table("media_attachments"). Column("id"). @@ -313,16 +313,16 @@ func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan t } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.conn.ProcessError(err) + return nil, m.db.ProcessError(err) } return m.GetAttachmentsByIDs(ctx, attachmentIDs) } -func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, db.Error) { +func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) { attachmentIDs := make([]string, 0, limit) - q := m.conn.NewSelect(). + q := m.db.NewSelect(). TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). Column("media_attachment.id"). WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery { @@ -341,16 +341,16 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.conn.ProcessError(err) + return nil, m.db.ProcessError(err) } return m.GetAttachmentsByIDs(ctx, attachmentIDs) } -func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) { +func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error) { attachmentIDs := make([]string, 0, limit) - q := m.conn. + q := m.db. NewSelect(). TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). Column("media_attachment.id"). @@ -367,14 +367,14 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.conn.ProcessError(err) + return nil, m.db.ProcessError(err) } return m.GetAttachmentsByIDs(ctx, attachmentIDs) } -func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) { - q := m.conn. +func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, error) { + q := m.db. NewSelect(). TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). Column("media_attachment.id"). @@ -387,7 +387,7 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t count, err := q.Count(ctx) if err != nil { - return 0, m.conn.ProcessError(err) + return 0, m.db.ProcessError(err) } return count, nil diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 9a41eb3b8..12d71a95a 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -31,21 +31,21 @@ import ( ) type mentionDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { +func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) { mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) { var mention gtsmodel.Mention - q := m.conn. + q := m.db. NewSelect(). Model(&mention). Where("? = ?", bun.Ident("mention.id"), id) if err := q.Scan(ctx); err != nil { - return nil, m.conn.ProcessError(err) + return nil, m.db.ProcessError(err) } return &mention, nil @@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio return mention, nil } -func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { +func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) { mentions := make([]*gtsmodel.Mention, 0, len(ids)) for _, id := range ids { @@ -104,8 +104,8 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel. func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { return m.state.Caches.GTS.Mention().Store(mention, func() error { - _, err := m.conn.NewInsert().Model(mention).Exec(ctx) - return m.conn.ProcessError(err) + _, err := m.db.NewInsert().Model(mention).Exec(ctx) + return m.db.ProcessError(err) }) } @@ -125,9 +125,9 @@ func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error { } // Finally delete mention from DB. - _, err = m.conn.NewDelete(). + _, err = m.db.NewDelete(). Table("mentions"). Where("? = ?", bun.Ident("id"), id). Exec(ctx) - return m.conn.ProcessError(err) + return m.db.ProcessError(err) } diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 277a935fd..b0757fb1e 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -31,19 +31,19 @@ import ( ) type notificationDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { +func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) { return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) { var notif gtsmodel.Notification - q := n.conn.NewSelect(). + q := n.db.NewSelect(). Model(¬if). Where("? = ?", bun.Ident("notification.id"), id) if err := q.Scan(ctx); err != nil { - return nil, n.conn.ProcessError(err) + return nil, n.db.ProcessError(err) } return ¬if, nil @@ -56,11 +56,11 @@ func (n *notificationDB) GetNotification( targetAccountID string, originAccountID string, statusID string, -) (*gtsmodel.Notification, db.Error) { +) (*gtsmodel.Notification, error) { return n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { var notif gtsmodel.Notification - q := n.conn.NewSelect(). + q := n.db.NewSelect(). Model(¬if). Where("? = ?", bun.Ident("notification_type"), notificationType). Where("? = ?", bun.Ident("target_account_id"), targetAccountID). @@ -68,7 +68,7 @@ func (n *notificationDB) GetNotification( Where("? = ?", bun.Ident("status_id"), statusID) if err := q.Scan(ctx); err != nil { - return nil, n.conn.ProcessError(err) + return nil, n.db.ProcessError(err) } return ¬if, nil @@ -83,7 +83,7 @@ func (n *notificationDB) GetAccountNotifications( minID string, limit int, excludeTypes []string, -) ([]*gtsmodel.Notification, db.Error) { +) ([]*gtsmodel.Notification, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -95,7 +95,7 @@ func (n *notificationDB) GetAccountNotifications( frontToBack = true ) - q := n.conn. + q := n.db. NewSelect(). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). Column("notification.id") @@ -140,7 +140,7 @@ func (n *notificationDB) GetAccountNotifications( } if err := q.Scan(ctx, ¬ifIDs); err != nil { - return nil, n.conn.ProcessError(err) + return nil, n.db.ProcessError(err) } if len(notifIDs) == 0 { @@ -174,12 +174,12 @@ func (n *notificationDB) GetAccountNotifications( func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { return n.state.Caches.GTS.Notification().Store(notif, func() error { - _, err := n.conn.NewInsert().Model(notif).Exec(ctx) - return n.conn.ProcessError(err) + _, err := n.db.NewInsert().Model(notif).Exec(ctx) + return n.db.ProcessError(err) }) } -func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error { +func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error { defer n.state.Caches.GTS.Notification().Invalidate("ID", id) // Load notif into cache before attempting a delete, @@ -195,21 +195,21 @@ func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) } // Finally delete notif from DB. - _, err = n.conn.NewDelete(). + _, err = n.db.NewDelete(). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). Where("? = ?", bun.Ident("notification.id"), id). Exec(ctx) - return n.conn.ProcessError(err) + return n.db.ProcessError(err) } -func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error { +func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error { if targetAccountID == "" && originAccountID == "" { return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set") } var notifIDs []string - q := n.conn. + q := n.db. NewSelect(). Column("id"). Table("notifications") @@ -227,7 +227,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string } if _, err := q.Exec(ctx, ¬ifIDs); err != nil { - return n.conn.ProcessError(err) + return n.db.ProcessError(err) } defer func() { @@ -248,24 +248,24 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string } // Finally delete all from DB. - _, err := n.conn.NewDelete(). + _, err := n.db.NewDelete(). Table("notifications"). Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). Exec(ctx) - return n.conn.ProcessError(err) + return n.db.ProcessError(err) } -func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error { +func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) error { var notifIDs []string - q := n.conn. + q := n.db. NewSelect(). Column("id"). Table("notifications"). Where("? = ?", bun.Ident("status_id"), statusID) if _, err := q.Exec(ctx, ¬ifIDs); err != nil { - return n.conn.ProcessError(err) + return n.db.ProcessError(err) } defer func() { @@ -286,9 +286,9 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu } // Finally delete all from DB. - _, err := n.conn.NewDelete(). + _, err := n.db.NewDelete(). Table("notifications"). Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). Exec(ctx) - return n.conn.ProcessError(err) + return n.db.ProcessError(err) } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 82559a213..c865f8aad 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -30,11 +30,11 @@ import ( ) type relationshipDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { +func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) { var rel gtsmodel.Relationship rel.ID = targetAccount @@ -90,91 +90,91 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { var followIDs []string - if err := newSelectFollows(r.conn, accountID). + if err := newSelectFollows(r.db, accountID). Scan(ctx, &followIDs); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return r.GetFollowsByIDs(ctx, followIDs) } func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { var followIDs []string - if err := newSelectLocalFollows(r.conn, accountID). + if err := newSelectLocalFollows(r.db, accountID). Scan(ctx, &followIDs); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return r.GetFollowsByIDs(ctx, followIDs) } func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { var followIDs []string - if err := newSelectFollowers(r.conn, accountID). + if err := newSelectFollowers(r.db, accountID). Scan(ctx, &followIDs); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return r.GetFollowsByIDs(ctx, followIDs) } func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { var followIDs []string - if err := newSelectLocalFollowers(r.conn, accountID). + if err := newSelectLocalFollowers(r.db, accountID). Scan(ctx, &followIDs); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return r.GetFollowsByIDs(ctx, followIDs) } func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollows(r.conn, accountID).Count(ctx) - return n, r.conn.ProcessError(err) + n, err := newSelectFollows(r.db, accountID).Count(ctx) + return n, r.db.ProcessError(err) } func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { - n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx) - return n, r.conn.ProcessError(err) + n, err := newSelectLocalFollows(r.db, accountID).Count(ctx) + return n, r.db.ProcessError(err) } func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowers(r.conn, accountID).Count(ctx) - return n, r.conn.ProcessError(err) + n, err := newSelectFollowers(r.db, accountID).Count(ctx) + return n, r.db.ProcessError(err) } func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { - n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx) - return n, r.conn.ProcessError(err) + n, err := newSelectLocalFollowers(r.db, accountID).Count(ctx) + return n, r.db.ProcessError(err) } func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { var followReqIDs []string - if err := newSelectFollowRequests(r.conn, accountID). + if err := newSelectFollowRequests(r.db, accountID). Scan(ctx, &followReqIDs); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return r.GetFollowRequestsByIDs(ctx, followReqIDs) } func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { var followReqIDs []string - if err := newSelectFollowRequesting(r.conn, accountID). + if err := newSelectFollowRequesting(r.db, accountID). Scan(ctx, &followReqIDs); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return r.GetFollowRequestsByIDs(ctx, followReqIDs) } func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx) - return n, r.conn.ProcessError(err) + n, err := newSelectFollowRequests(r.db, accountID).Count(ctx) + return n, r.db.ProcessError(err) } func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx) - return n, r.conn.ProcessError(err) + n, err := newSelectFollowRequesting(r.db, accountID).Count(ctx) + return n, r.db.ProcessError(err) } // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. -func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery { - return conn.NewSelect(). +func newSelectFollowRequests(db *WrappedDB, accountID string) *bun.SelectQuery { + return db.NewSelect(). TableExpr("?", bun.Ident("follow_requests")). ColumnExpr("?", bun.Ident("id")). Where("? = ?", bun.Ident("target_account_id"), accountID). @@ -182,8 +182,8 @@ func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery { } // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. -func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery { - return conn.NewSelect(). +func newSelectFollowRequesting(db *WrappedDB, accountID string) *bun.SelectQuery { + return db.NewSelect(). TableExpr("?", bun.Ident("follow_requests")). ColumnExpr("?", bun.Ident("id")). Where("? = ?", bun.Ident("target_account_id"), accountID). @@ -191,8 +191,8 @@ func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery } // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. -func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery { - return conn.NewSelect(). +func newSelectFollows(db *WrappedDB, accountID string) *bun.SelectQuery { + return db.NewSelect(). Table("follows"). Column("id"). Where("? = ?", bun.Ident("account_id"), accountID). @@ -201,15 +201,15 @@ func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery { // newSelectLocalFollows returns a new select query for all rows in the follows table with // account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). -func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery { - return conn.NewSelect(). +func newSelectLocalFollows(db *WrappedDB, accountID string) *bun.SelectQuery { + return db.NewSelect(). Table("follows"). Column("id"). Where("? = ? AND ? IN (?)", bun.Ident("account_id"), accountID, bun.Ident("target_account_id"), - conn.NewSelect(). + db.NewSelect(). Table("accounts"). Column("id"). Where("? IS NULL", bun.Ident("domain")), @@ -218,8 +218,8 @@ func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery { } // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. -func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery { - return conn.NewSelect(). +func newSelectFollowers(db *WrappedDB, accountID string) *bun.SelectQuery { + return db.NewSelect(). Table("follows"). Column("id"). Where("? = ?", bun.Ident("target_account_id"), accountID). @@ -228,15 +228,15 @@ func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery { // newSelectLocalFollowers returns a new select query for all rows in the follows table with // target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). -func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery { - return conn.NewSelect(). +func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery { + return db.NewSelect(). Table("follows"). Column("id"). Where("? = ? AND ? IN (?)", bun.Ident("target_account_id"), accountID, bun.Ident("account_id"), - conn.NewSelect(). + db.NewSelect(). Table("accounts"). Column("id"). Where("? IS NULL", bun.Ident("domain")), diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go index fa68a2e97..948e82fcb 100644 --- a/internal/db/bundb/relationship_block.go +++ b/internal/db/bundb/relationship_block.go @@ -28,7 +28,7 @@ import ( "github.com/uptrace/bun" ) -func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { +func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) { block, err := r.GetBlock( gtscontext.SetBarebones(ctx), sourceAccountID, @@ -61,7 +61,7 @@ func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel ctx, "ID", func(block *gtsmodel.Block) error { - return r.conn.NewSelect().Model(block). + return r.db.NewSelect().Model(block). Where("? = ?", bun.Ident("block.id"), id). Scan(ctx) }, @@ -74,7 +74,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod ctx, "URI", func(block *gtsmodel.Block) error { - return r.conn.NewSelect().Model(block). + return r.db.NewSelect().Model(block). Where("? = ?", bun.Ident("block.uri"), uri). Scan(ctx) }, @@ -87,7 +87,7 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t ctx, "AccountID.TargetAccountID", func(block *gtsmodel.Block) error { - return r.conn.NewSelect().Model(block). + return r.db.NewSelect().Model(block). Where("? = ?", bun.Ident("block.account_id"), sourceAccountID). Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID). Scan(ctx) @@ -104,7 +104,7 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu // Not cached! Perform database query if err := dbQuery(&block); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return &block, nil @@ -142,8 +142,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { return r.state.Caches.GTS.Block().Store(block, func() error { - _, err := r.conn.NewInsert().Model(block).Exec(ctx) - return r.conn.ProcessError(err) + _, err := r.db.NewInsert().Model(block).Exec(ctx) + return r.db.ProcessError(err) }) } @@ -163,11 +163,11 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { } // Finally delete block from DB. - _, err = r.conn.NewDelete(). + _, err = r.db.NewDelete(). Table("blocks"). Where("? = ?", bun.Ident("id"), id). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error { @@ -186,18 +186,18 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error } // Finally delete block from DB. - _, err = r.conn.NewDelete(). + _, err = r.db.NewDelete(). Table("blocks"). Where("? = ?", bun.Ident("uri"), uri). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error { var blockIDs []string // Get full list of IDs. - if err := r.conn.NewSelect(). + if err := r.db.NewSelect(). Column("id"). Table("blocks"). WhereOr("? = ? OR ? = ?", @@ -207,7 +207,7 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri accountID, ). Scan(ctx, &blockIDs); err != nil { - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } defer func() { @@ -228,9 +228,9 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri } // Finally delete all from DB. - _, err := r.conn.NewDelete(). + _, err := r.db.NewDelete(). Table("blocks"). Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 349c1ef43..84501b0be 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -36,7 +36,7 @@ func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmode ctx, "ID", func(follow *gtsmodel.Follow) error { - return r.conn.NewSelect(). + return r.db.NewSelect(). Model(follow). Where("? = ?", bun.Ident("id"), id). Scan(ctx) @@ -50,7 +50,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo ctx, "URI", func(follow *gtsmodel.Follow) error { - return r.conn.NewSelect(). + return r.db.NewSelect(). Model(follow). Where("? = ?", bun.Ident("uri"), uri). Scan(ctx) @@ -64,7 +64,7 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, ctx, "AccountID.TargetAccountID", func(follow *gtsmodel.Follow) error { - return r.conn.NewSelect(). + return r.db.NewSelect(). Model(follow). Where("? = ?", bun.Ident("account_id"), sourceAccountID). Where("? = ?", bun.Ident("target_account_id"), targetAccountID). @@ -94,7 +94,7 @@ func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]* return follows, nil } -func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { +func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) { follow, err := r.GetFollow( gtscontext.SetBarebones(ctx), sourceAccountID, @@ -106,7 +106,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string return (follow != nil), nil } -func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) { +func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, error) { // make sure account 1 follows account 2 f1, err := r.IsFollowing(ctx, accountID1, @@ -135,7 +135,7 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f // Not cached! Perform database query if err := dbQuery(&follow); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return &follow, nil @@ -190,8 +190,8 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { return r.state.Caches.GTS.Follow().Store(follow, func() error { - _, err := r.conn.NewInsert().Model(follow).Exec(ctx) - return r.conn.ProcessError(err) + _, err := r.db.NewInsert().Model(follow).Exec(ctx) + return r.db.ProcessError(err) }) } @@ -203,12 +203,12 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll } return r.state.Caches.GTS.Follow().Store(follow, func() error { - if _, err := r.conn.NewUpdate(). + if _, err := r.db.NewUpdate(). Model(follow). Where("? = ?", bun.Ident("follow.id"), follow.ID). Column(columns...). Exec(ctx); err != nil { - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } return nil @@ -217,11 +217,11 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error { // Delete the follow itself using the given ID. - if _, err := r.conn.NewDelete(). + if _, err := r.db.NewDelete(). Table("follows"). Where("? = ?", bun.Ident("id"), id). Exec(ctx); err != nil { - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } // Delete every list entry that used this followID. @@ -297,7 +297,7 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str var followIDs []string // Get full list of IDs. - if _, err := r.conn. + if _, err := r.db. NewSelect(). Column("id"). Table("follows"). @@ -308,7 +308,7 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str accountID, ). Exec(ctx, &followIDs); err != nil { - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } defer func() { diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index f233bdd58..a6e913953 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -35,7 +35,7 @@ func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (* ctx, "ID", func(followReq *gtsmodel.FollowRequest) error { - return r.conn.NewSelect(). + return r.db.NewSelect(). Model(followReq). Where("? = ?", bun.Ident("id"), id). Scan(ctx) @@ -49,7 +49,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) ctx, "URI", func(followReq *gtsmodel.FollowRequest) error { - return r.conn.NewSelect(). + return r.db.NewSelect(). Model(followReq). Where("? = ?", bun.Ident("uri"), uri). Scan(ctx) @@ -63,7 +63,7 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s ctx, "AccountID.TargetAccountID", func(followReq *gtsmodel.FollowRequest) error { - return r.conn.NewSelect(). + return r.db.NewSelect(). Model(followReq). Where("? = ?", bun.Ident("account_id"), sourceAccountID). Where("? = ?", bun.Ident("target_account_id"), targetAccountID). @@ -93,7 +93,7 @@ func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []strin return followReqs, nil } -func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { +func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) { followReq, err := r.GetFollowRequest( gtscontext.SetBarebones(ctx), sourceAccountID, @@ -112,7 +112,7 @@ func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, db // Not cached! Perform database query if err := dbQuery(&followReq); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return &followReq, nil @@ -150,8 +150,8 @@ func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, db func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { return r.state.Caches.GTS.FollowRequest().Store(follow, func() error { - _, err := r.conn.NewInsert().Model(follow).Exec(ctx) - return r.conn.ProcessError(err) + _, err := r.db.NewInsert().Model(follow).Exec(ctx) + return r.db.ProcessError(err) }) } @@ -163,19 +163,19 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest } return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error { - if _, err := r.conn.NewUpdate(). + if _, err := r.db.NewUpdate(). Model(followRequest). Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). Column(columns...). Exec(ctx); err != nil { - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } return nil }) } -func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { +func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { // Get original follow request. followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID) if err != nil { @@ -198,12 +198,12 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI if err := r.state.Caches.GTS.Follow().Store(follow, func() error { // If the follow already exists, just // replace the URI with the new one. - _, err := r.conn. + _, err := r.db. NewInsert(). Model(follow). On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) }); err != nil { return nil, err } @@ -212,12 +212,12 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID) // Delete original follow request. - if _, err := r.conn. + if _, err := r.db. NewDelete(). Table("follow_requests"). Where("? = ?", bun.Ident("id"), followReq.ID). Exec(ctx); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } // Delete original follow request notification @@ -230,7 +230,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI return follow, nil } -func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error { +func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) error { // Delete follow request first. if err := r.DeleteFollowRequest(ctx, sourceAccountID, targetAccountID); err != nil { return err @@ -262,11 +262,11 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI } // Finally delete followreq from DB. - _, err = r.conn.NewDelete(). + _, err = r.db.NewDelete(). Table("follow_requests"). Where("? = ?", bun.Ident("id"), follow.ID). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error { @@ -285,11 +285,11 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) } // Finally delete followreq from DB. - _, err = r.conn.NewDelete(). + _, err = r.db.NewDelete(). Table("follow_requests"). Where("? = ?", bun.Ident("id"), id). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error { @@ -308,18 +308,18 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin } // Finally delete followreq from DB. - _, err = r.conn.NewDelete(). + _, err = r.db.NewDelete(). Table("follow_requests"). Where("? = ?", bun.Ident("uri"), uri). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error { var followReqIDs []string // Get full list of IDs. - if _, err := r.conn. + if _, err := r.db. NewSelect(). Column("id"). Table("follow_requestss"). @@ -330,7 +330,7 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun accountID, ). Exec(ctx, &followReqIDs); err != nil { - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } defer func() { @@ -351,9 +351,9 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun } // Finally delete all from DB. - _, err := r.conn.NewDelete(). + _, err := r.db.NewDelete(). Table("follow_requests"). Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go index ee8aa1cb3..3a1e18789 100644 --- a/internal/db/bundb/report.go +++ b/internal/db/bundb/report.go @@ -32,15 +32,15 @@ import ( ) type reportDB struct { - conn *DBConn + db *WrappedDB state *state.State } func (r *reportDB) newReportQ(report interface{}) *bun.SelectQuery { - return r.conn.NewSelect().Model(report) + return r.db.NewSelect().Model(report) } -func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, db.Error) { +func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, error) { return r.getReport( ctx, "ID", @@ -51,10 +51,10 @@ func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Repo ) } -func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, db.Error) { +func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, error) { reportIDs := []string{} - q := r.conn. + q := r.db. NewSelect(). TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")). Column("report.id"). @@ -94,7 +94,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str } if err := q.Scan(ctx, &reportIDs); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } // Catch case of no reports early @@ -118,14 +118,14 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str return reports, nil } -func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, db.Error) { +func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) { // Fetch report from database cache with loader callback report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) { var report gtsmodel.Report // Not cached! Perform database query if err := dbQuery(&report); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } return &report, nil @@ -166,34 +166,34 @@ func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*g return report, nil } -func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) db.Error { +func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error { return r.state.Caches.GTS.Report().Store(report, func() error { - _, err := r.conn.NewInsert().Model(report).Exec(ctx) - return r.conn.ProcessError(err) + _, err := r.db.NewInsert().Model(report).Exec(ctx) + return r.db.ProcessError(err) }) } -func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, db.Error) { +func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, error) { // Update the report's last-updated report.UpdatedAt = time.Now() if len(columns) != 0 { columns = append(columns, "updated_at") } - if _, err := r.conn. + if _, err := r.db. NewUpdate(). Model(report). Where("? = ?", bun.Ident("report.id"), report.ID). Column(columns...). Exec(ctx); err != nil { - return nil, r.conn.ProcessError(err) + return nil, r.db.ProcessError(err) } r.state.Caches.GTS.Report().Invalidate("ID", report.ID) return report, nil } -func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error { +func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error { defer r.state.Caches.GTS.Report().Invalidate("ID", id) // Load status into cache before attempting a delete, @@ -209,9 +209,9 @@ func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error { } // Finally delete report from DB. - _, err = r.conn.NewDelete(). + _, err = r.db.NewDelete(). TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")). Where("? = ?", bun.Ident("report.id"), id). Exec(ctx) - return r.conn.ProcessError(err) + return r.db.ProcessError(err) } diff --git a/internal/db/bundb/search.go b/internal/db/bundb/search.go index 1d7eefd48..f4e41d0f4 100644 --- a/internal/db/bundb/search.go +++ b/internal/db/bundb/search.go @@ -56,7 +56,7 @@ import ( // This isn't ideal, of course, but at least we could cover the most common use case of // a caller paging down through results. type searchDB struct { - conn *DBConn + db *WrappedDB state *state.State } @@ -89,7 +89,7 @@ func (s *searchDB) SearchForAccounts( frontToBack = true ) - q := s.conn. + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). // Select only IDs from table. @@ -148,7 +148,7 @@ func (s *searchDB) SearchForAccounts( } if err := q.Scan(ctx, &accountIDs); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } if len(accountIDs) == 0 { @@ -183,7 +183,7 @@ func (s *searchDB) SearchForAccounts( // followedAccounts returns a subquery that selects only IDs // of accounts that are followed by the given accountID. func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery { - return s.conn. + return s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). Column("follow.target_account_id"). @@ -196,7 +196,7 @@ func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery { // in the concatenation. func (s *searchDB) accountText(following bool) *bun.SelectQuery { var ( - accountText = s.conn.NewSelect() + accountText = s.db.NewSelect() query string args []interface{} ) @@ -225,7 +225,7 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery { // different number of placeholders depending on // following/not following. COALESCE calls ensure // that we're not trying to concatenate null values. - d := s.conn.Dialect().Name() + d := s.db.Dialect().Name() switch { case d == dialect.SQLite && following: @@ -276,7 +276,7 @@ func (s *searchDB) SearchForStatuses( frontToBack = true ) - q := s.conn. + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). // Select only IDs from table @@ -326,7 +326,7 @@ func (s *searchDB) SearchForStatuses( } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } if len(statusIDs) == 0 { @@ -361,11 +361,11 @@ func (s *searchDB) SearchForStatuses( // statusText returns a subquery that selects a concatenation // of status content and content warning as "status_text". func (s *searchDB) statusText() *bun.SelectQuery { - statusText := s.conn.NewSelect() + statusText := s.db.NewSelect() // SQLite and Postgres use different // syntaxes for concatenation. - switch s.conn.Dialect().Name() { + switch s.db.Dialect().Name() { case dialect.SQLite: statusText = statusText.ColumnExpr( diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go index 9a4256de5..8d778ffa2 100644 --- a/internal/db/bundb/session.go +++ b/internal/db/bundb/session.go @@ -22,26 +22,25 @@ import ( "crypto/rand" "io" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" ) type sessionDB struct { - conn *DBConn + db *WrappedDB } -func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { +func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) { rss := make([]*gtsmodel.RouterSession, 0, 1) // get the first router session in the db or... - if err := s.conn. + if err := s.db. NewSelect(). Model(&rss). Limit(1). Order("router_session.id DESC"). Scan(ctx); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } // ... create a new one @@ -52,7 +51,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db return rss[0], nil } -func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { +func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, error) { buf := make([]byte, 64) auth := buf[:32] crypt := buf[32:64] @@ -67,11 +66,11 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, Crypt: crypt, } - if _, err := s.conn. + if _, err := s.db. NewInsert(). Model(rs). Exec(ctx); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } return rs, nil diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index ccfc9fd4b..a019216d0 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -35,19 +35,19 @@ import ( ) type statusDB struct { - conn *DBConn + db *WrappedDB state *state.State } func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { - return s.conn. + return s.db. NewSelect(). Model(status). Relation("Tags"). Relation("CreatedWithApplication") } -func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { +func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) { return s.getStatus( ctx, "ID", @@ -76,7 +76,7 @@ func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmo return statuses, nil } -func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { +func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, error) { return s.getStatus( ctx, "URI", @@ -87,7 +87,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St ) } -func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { +func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, error) { return s.getStatus( ctx, "URL", @@ -98,14 +98,14 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St ) } -func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) { +func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) { // Fetch status from database cache with loader callback status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) { var status gtsmodel.Status // Not cached! Perform database query. if err := dbQuery(&status); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } return &status, nil @@ -243,12 +243,12 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) return errs.Combine() } -func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { +func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error { return s.state.Caches.GTS.Status().Store(status, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return s.conn.RunInTx(ctx, func(tx bun.Tx) error { + return s.db.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx. @@ -259,7 +259,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er }). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")). Exec(ctx); err != nil { - err = s.conn.ProcessError(err) + err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -276,7 +276,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er }). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")). Exec(ctx); err != nil { - err = s.conn.ProcessError(err) + err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -292,7 +292,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er Model(a). Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { - err = s.conn.ProcessError(err) + err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -306,7 +306,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er }) } -func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) db.Error { +func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) error { status.UpdatedAt = time.Now() if len(columns) > 0 { // If we're updating by column, ensure "updated_at" is included. @@ -317,7 +317,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return s.conn.RunInTx(ctx, func(tx bun.Tx) error { + return s.db.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx. @@ -328,7 +328,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co }). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")). Exec(ctx); err != nil { - err = s.conn.ProcessError(err) + err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -345,7 +345,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co }). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")). Exec(ctx); err != nil { - err = s.conn.ProcessError(err) + err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -361,7 +361,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co Model(a). Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { - err = s.conn.ProcessError(err) + err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -380,7 +380,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co }) } -func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { +func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error { defer s.state.Caches.GTS.Status().Invalidate("ID", id) // Load status into cache before attempting a delete, @@ -397,7 +397,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { return err } - return s.conn.RunInTx(ctx, func(tx bun.Tx) error { + return s.db.RunInTx(ctx, func(tx bun.Tx) error { // delete links between this status and any emojis it uses if _, err := tx. NewDelete(). @@ -433,7 +433,7 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([ var statusIDs []string // Create SELECT status query. - q := s.conn.NewSelect(). + q := s.db.NewSelect(). Table("statuses"). Column("id") @@ -450,14 +450,14 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([ // Execute the query, scanning destination into statusIDs. if _, err := q.Exec(ctx, &statusIDs); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } // Convert status IDs into status objects. return s.GetStatusesByIDs(ctx, statusIDs) } -func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { +func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) { if onlyDirect { // Only want the direct parent, no further than first level parent, err := s.GetStatusByID(ctx, status.InReplyToID) @@ -485,7 +485,7 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status return parents, nil } -func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { +func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) { foundStatuses := &list.List{} foundStatuses.PushFront(status) s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID) @@ -509,7 +509,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { var childIDs []string - q := s.conn. + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Column("status.id"). @@ -554,71 +554,71 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, } } -func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { - return s.conn. +func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error) { + return s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID). Count(ctx) } -func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { - return s.conn. +func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error) { + return s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). Count(ctx) } -func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { - return s.conn. +func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error) { + return s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). Count(ctx) } -func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { - q := s.conn. +func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). Where("? = ?", bun.Ident("status_fave.account_id"), accountID) - return s.conn.Exists(ctx, q) + return s.db.Exists(ctx, q) } -func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { - q := s.conn. +func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). Where("? = ?", bun.Ident("status.account_id"), accountID) - return s.conn.Exists(ctx, q) + return s.db.Exists(ctx, q) } -func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { - q := s.conn. +func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")). Where("? = ?", bun.Ident("status_mute.status_id"), status.ID). Where("? = ?", bun.Ident("status_mute.account_id"), accountID) - return s.conn.Exists(ctx, q) + return s.db.Exists(ctx, q) } -func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { - q := s.conn. +func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) { + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID). Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID) - return s.conn.Exists(ctx, q) + return s.db.Exists(ctx, q) } -func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { +func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) { reblogs := []*gtsmodel.Status{} q := s. @@ -626,7 +626,7 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status Where("? = ?", bun.Ident("status.boost_of_id"), status.ID) if err := q.Scan(ctx); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } return reblogs, nil } diff --git a/internal/db/bundb/statusbookmark.go b/internal/db/bundb/statusbookmark.go index f7d19cd23..8a3c4dad6 100644 --- a/internal/db/bundb/statusbookmark.go +++ b/internal/db/bundb/statusbookmark.go @@ -22,7 +22,6 @@ import ( "errors" "fmt" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" @@ -30,20 +29,20 @@ import ( ) type statusBookmarkDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, db.Error) { +func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, error) { bookmark := new(gtsmodel.StatusBookmark) - err := s.conn. + err := s.db. NewSelect(). Model(bookmark). Where("? = ?", bun.Ident("status_bookmark.id"), id). Scan(ctx) if err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } bookmark.Account, err = s.state.DB.GetAccountByID(ctx, bookmark.AccountID) @@ -64,10 +63,10 @@ func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*g return bookmark, nil } -func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, db.Error) { +func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, error) { var id string - q := s.conn. + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). Column("status_bookmark.id"). @@ -76,13 +75,13 @@ func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID st Limit(1) if err := q.Scan(ctx, &id); err != nil { - return "", s.conn.ProcessError(err) + return "", s.db.ProcessError(err) } return id, nil } -func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, db.Error) { +func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -91,7 +90,7 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str // Guess size of IDs based on limit. ids := make([]string, 0, limit) - q := s.conn. + q := s.db. NewSelect(). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). Column("status_bookmark.id"). @@ -115,7 +114,7 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str } if err := q.Scan(ctx, &ids); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } bookmarks := make([]*gtsmodel.StatusBookmark, 0, len(ids)) @@ -133,26 +132,26 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str return bookmarks, nil } -func (s *statusBookmarkDB) PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) db.Error { - _, err := s.conn. +func (s *statusBookmarkDB) PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) error { + _, err := s.db. NewInsert(). Model(statusBookmark). Exec(ctx) - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } -func (s *statusBookmarkDB) DeleteStatusBookmark(ctx context.Context, id string) db.Error { - _, err := s.conn. +func (s *statusBookmarkDB) DeleteStatusBookmark(ctx context.Context, id string) error { + _, err := s.db. NewDelete(). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). Where("? = ?", bun.Ident("status_bookmark.id"), id). Exec(ctx) - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } -func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) db.Error { +func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) error { if targetAccountID == "" && originAccountID == "" { return errors.New("DeleteBookmarks: one of targetAccountID or originAccountID must be set") } @@ -161,7 +160,7 @@ func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAcco // statement (when bookmarks have a cache), // + use the IDs to invalidate cache entries. - q := s.conn. + q := s.db. NewDelete(). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")) @@ -174,24 +173,24 @@ func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAcco } if _, err := q.Exec(ctx); err != nil { - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } return nil } -func (s *statusBookmarkDB) DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) db.Error { +func (s *statusBookmarkDB) DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) error { // TODO: Capture bookmark IDs in a RETURNING // statement (when bookmarks have a cache), // + use the IDs to invalidate cache entries. - q := s.conn. + q := s.db. NewDelete(). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). Where("? = ?", bun.Ident("status_bookmark.status_id"), statusID) if _, err := q.Exec(ctx); err != nil { - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } return nil diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index 1c96a1dd0..a8d1cd0d1 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -32,16 +32,16 @@ import ( ) type statusFaveDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) { +func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) { return s.getStatusFave( ctx, "AccountID.StatusID", func(fave *gtsmodel.StatusFave) error { - return s.conn. + return s.db. NewSelect(). Model(fave). Where("? = ?", bun.Ident("account_id"), accountID). @@ -53,12 +53,12 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, stat ) } -func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) { +func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, error) { return s.getStatusFave( ctx, "ID", func(fave *gtsmodel.StatusFave) error { - return s.conn. + return s.db. NewSelect(). Model(fave). Where("? = ?", bun.Ident("id"), id). @@ -75,7 +75,7 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery // Not cached! Perform database query. if err := dbQuery(&fave); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } return &fave, nil @@ -119,16 +119,16 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery return fave, nil } -func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) { +func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) { ids := []string{} - if err := s.conn. + if err := s.db. NewSelect(). Table("status_faves"). Column("id"). Where("? = ?", bun.Ident("status_id"), statusID). Scan(ctx, &ids); err != nil { - return nil, s.conn.ProcessError(err) + return nil, s.db.ProcessError(err) } faves := make([]*gtsmodel.StatusFave, 0, len(ids)) @@ -188,17 +188,17 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo return errs.Combine() } -func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error { +func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error { return s.state.Caches.GTS.StatusFave().Store(fave, func() error { - _, err := s.conn. + _, err := s.db. NewInsert(). Model(fave). Exec(ctx) - return s.conn.ProcessError(err) + return s.db.ProcessError(err) }) } -func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error { +func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) error { defer s.state.Caches.GTS.StatusFave().Invalidate("ID", id) // Load fave into cache before attempting a delete, @@ -214,21 +214,21 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.E } // Finally delete fave from DB. - _, err = s.conn.NewDelete(). + _, err = s.db.NewDelete(). Table("status_faves"). Where("? = ?", bun.Ident("id"), id). Exec(ctx) - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } -func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error { +func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error { if targetAccountID == "" && originAccountID == "" { return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set") } var faveIDs []string - q := s.conn. + q := s.db. NewSelect(). Column("id"). Table("status_faves") @@ -242,7 +242,7 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st } if _, err := q.Exec(ctx, &faveIDs); err != nil { - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } defer func() { @@ -263,24 +263,24 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st } // Finally delete all from DB. - _, err := s.conn.NewDelete(). + _, err := s.db.NewDelete(). Table("status_faves"). Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)). Exec(ctx) - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } -func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error { +func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) error { // Capture fave IDs in a RETURNING statement. var faveIDs []string - q := s.conn. + q := s.db. NewSelect(). Column("id"). Table("status_faves"). Where("? = ?", bun.Ident("status_id"), statusID) if _, err := q.Exec(ctx, &faveIDs); err != nil { - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } defer func() { @@ -301,9 +301,9 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID } // Finally delete all from DB. - _, err := s.conn.NewDelete(). + _, err := s.db.NewDelete(). Table("status_faves"). Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)). Exec(ctx) - return s.conn.ProcessError(err) + return s.db.ProcessError(err) } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index 0859f3ff3..6aa4989d9 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -33,11 +33,11 @@ import ( ) type timelineDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { +func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -49,7 +49,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI frontToBack = true ) - q := t.conn. + q := t.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). // Select only IDs from table @@ -103,7 +103,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI // Subquery to select target (followed) account // IDs from follows owned by given accountID. - subQ := t.conn. + subQ := t.db. NewSelect(). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). Column("follow.target_account_id"). @@ -119,7 +119,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI }) if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, t.conn.ProcessError(err) + return nil, t.db.ProcessError(err) } if len(statusIDs) == 0 { @@ -151,7 +151,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI return statuses, nil } -func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { +func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -160,7 +160,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI // Make educated guess for slice size statusIDs := make([]string, 0, limit) - q := t.conn. + q := t.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Column("status.id"). @@ -202,7 +202,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, t.conn.ProcessError(err) + return nil, t.db.ProcessError(err) } statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) @@ -224,7 +224,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. -func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { +func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -233,7 +233,7 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max // Make educated guess for slice size faves := make([]*gtsmodel.StatusFave, 0, limit) - fq := t.conn. + fq := t.db. NewSelect(). Model(&faves). Where("? = ?", bun.Ident("status_fave.account_id"), accountID). @@ -253,7 +253,7 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max err := fq.Scan(ctx) if err != nil { - return nil, "", "", t.conn.ProcessError(err) + return nil, "", "", t.db.ProcessError(err) } if len(faves) == 0 { @@ -322,7 +322,7 @@ func (t *timelineDB) GetListTimeline( } // Select target account IDs from follows. - subQ := t.conn. + subQ := t.db. NewSelect(). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). Column("follow.target_account_id"). @@ -330,7 +330,7 @@ func (t *timelineDB) GetListTimeline( // Select only status IDs created // by one of the followed accounts. - q := t.conn. + q := t.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). // Select only IDs from table @@ -379,7 +379,7 @@ func (t *timelineDB) GetListTimeline( } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, t.conn.ProcessError(err) + return nil, t.db.ProcessError(err) } if len(statusIDs) == 0 { diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go index 668bde1af..0050e6531 100644 --- a/internal/db/bundb/tombstone.go +++ b/internal/db/bundb/tombstone.go @@ -27,28 +27,28 @@ import ( ) type tombstoneDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, db.Error) { +func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) { return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) { var tomb gtsmodel.Tombstone - q := t.conn. + q := t.db. NewSelect(). Model(&tomb). Where("? = ?", bun.Ident("tombstone.uri"), uri) if err := q.Scan(ctx); err != nil { - return nil, t.conn.ProcessError(err) + return nil, t.db.ProcessError(err) } return &tomb, nil }, uri) } -func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (bool, db.Error) { +func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (bool, error) { tomb, err := t.GetTombstoneByURI(ctx, uri) if err == db.ErrNoEntries { err = nil @@ -56,23 +56,23 @@ func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (b return (tomb != nil), err } -func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) db.Error { +func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error { return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error { - _, err := t.conn. + _, err := t.db. NewInsert(). Model(tombstone). Exec(ctx) - return t.conn.ProcessError(err) + return t.db.ProcessError(err) }) } -func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) db.Error { +func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error { defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id) // Delete tombstone from DB. - _, err := t.conn.NewDelete(). + _, err := t.db.NewDelete(). TableExpr("? AS ?", bun.Ident("tombstones"), bun.Ident("tombstone")). Where("? = ?", bun.Ident("tombstone.id"), id). Exec(ctx) - return t.conn.ProcessError(err) + return t.db.ProcessError(err) } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index c2ea5a67d..4b38d48fa 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -30,125 +30,125 @@ import ( ) type userDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { +func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("ID", func() (*gtsmodel.User, error) { var user gtsmodel.User - q := u.conn. + q := u.db. NewSelect(). Model(&user). Relation("Account"). Where("? = ?", bun.Ident("user.id"), id) if err := q.Scan(ctx); err != nil { - return nil, u.conn.ProcessError(err) + return nil, u.db.ProcessError(err) } return &user, nil }, id) } -func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) { +func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("AccountID", func() (*gtsmodel.User, error) { var user gtsmodel.User - q := u.conn. + q := u.db. NewSelect(). Model(&user). Relation("Account"). Where("? = ?", bun.Ident("user.account_id"), accountID) if err := q.Scan(ctx); err != nil { - return nil, u.conn.ProcessError(err) + return nil, u.db.ProcessError(err) } return &user, nil }, accountID) } -func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) { +func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("Email", func() (*gtsmodel.User, error) { var user gtsmodel.User - q := u.conn. + q := u.db. NewSelect(). Model(&user). Relation("Account"). Where("? = ?", bun.Ident("user.email"), emailAddress) if err := q.Scan(ctx); err != nil { - return nil, u.conn.ProcessError(err) + return nil, u.db.ProcessError(err) } return &user, nil }, emailAddress) } -func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { +func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("ExternalID", func() (*gtsmodel.User, error) { var user gtsmodel.User - q := u.conn. + q := u.db. NewSelect(). Model(&user). Relation("Account"). Where("? = ?", bun.Ident("user.external_id"), id) if err := q.Scan(ctx); err != nil { - return nil, u.conn.ProcessError(err) + return nil, u.db.ProcessError(err) } return &user, nil }, id) } -func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) { +func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("ConfirmationToken", func() (*gtsmodel.User, error) { var user gtsmodel.User - q := u.conn. + q := u.db. NewSelect(). Model(&user). Relation("Account"). Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken) if err := q.Scan(ctx); err != nil { - return nil, u.conn.ProcessError(err) + return nil, u.db.ProcessError(err) } return &user, nil }, confirmationToken) } -func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, db.Error) { +func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { var users []*gtsmodel.User - q := u.conn. + q := u.db. NewSelect(). Model(&users). Relation("Account") if err := q.Scan(ctx); err != nil { - return nil, u.conn.ProcessError(err) + return nil, u.db.ProcessError(err) } return users, nil } -func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) db.Error { +func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { return u.state.Caches.GTS.User().Store(user, func() error { - _, err := u.conn. + _, err := u.db. NewInsert(). Model(user). Exec(ctx) - return u.conn.ProcessError(err) + return u.db.ProcessError(err) }) } -func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) db.Error { +func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) error { // Update the user's last-updated user.UpdatedAt = time.Now() @@ -158,17 +158,17 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns .. } return u.state.Caches.GTS.User().Store(user, func() error { - _, err := u.conn. + _, err := u.db. NewUpdate(). Model(user). Where("? = ?", bun.Ident("user.id"), user.ID). Column(columns...). Exec(ctx) - return u.conn.ProcessError(err) + return u.db.ProcessError(err) }) } -func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { +func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error { defer u.state.Caches.GTS.User().Invalidate("ID", userID) // Load user into cache before attempting a delete, @@ -184,9 +184,9 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { } // Finally delete user from DB. - _, err = u.conn.NewDelete(). + _, err = u.db.NewDelete(). TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). Where("? = ?", bun.Ident("user.id"), userID). Exec(ctx) - return u.conn.ProcessError(err) + return u.db.ProcessError(err) } diff --git a/internal/db/bundb/wrap.go b/internal/db/bundb/wrap.go new file mode 100644 index 000000000..a5039914a --- /dev/null +++ b/internal/db/bundb/wrap.go @@ -0,0 +1,258 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "database/sql" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" +) + +// WrappedDB wraps a bun database instance +// to provide common per-dialect SQL error +// conversions to common types, and retries +// on returned busy errors (SQLite only for now). +type WrappedDB struct { + errHook func(error) error + *bun.DB // underlying conn +} + +// WrapDB wraps a bun database instance in our own WrappedDB type. +func WrapDB(db *bun.DB) *WrappedDB { + var errProc func(error) error + switch name := db.Dialect().Name(); name { + case dialect.PG: + errProc = processPostgresError + case dialect.SQLite: + errProc = processSQLiteError + default: + panic("unknown dialect name: " + name.String()) + } + return &WrappedDB{ + errHook: errProc, + DB: db, + } +} + +// BeginTx wraps bun.DB.BeginTx() with retry-busy timeout. +func (db *WrappedDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (tx bun.Tx, err error) { + err = retryOnBusy(ctx, func() error { + tx, err = db.DB.BeginTx(ctx, opts) + err = db.ProcessError(err) + return err + }) + return +} + +// ExecContext wraps bun.DB.ExecContext() with retry-busy timeout. +func (db *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { + err = retryOnBusy(ctx, func() error { + result, err = db.DB.ExecContext(ctx, query, args...) + err = db.ProcessError(err) + return err + }) + return +} + +// QueryContext wraps bun.DB.QueryContext() with retry-busy timeout. +func (db *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { + err = retryOnBusy(ctx, func() error { + rows, err = db.DB.QueryContext(ctx, query, args...) + err = db.ProcessError(err) + return err + }) + return +} + +// QueryRowContext wraps bun.DB.QueryRowContext() with retry-busy timeout. +func (db *WrappedDB) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) { + _ = retryOnBusy(ctx, func() error { + row = db.DB.QueryRowContext(ctx, query, args...) + err := db.ProcessError(row.Err()) + return err + }) + return +} + +// RunInTx is functionally the same as bun.DB.RunInTx() but with retry-busy timeouts. +func (db *WrappedDB) RunInTx(ctx context.Context, fn func(bun.Tx) error) error { + // Attempt to start new transaction. + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + var done bool + + defer func() { + if !done { + // Rollback (with retry-backoff). + _ = retryOnBusy(ctx, func() error { + err := tx.Rollback() + return db.errHook(err) + }) + } + }() + + // Perform supplied transaction + if err := fn(tx); err != nil { + return db.errHook(err) + } + + // Commit (with retry-backoff). + err = retryOnBusy(ctx, func() error { + err := tx.Commit() + return db.errHook(err) + }) + done = true + return err +} + +func (db *WrappedDB) NewValues(model interface{}) *bun.ValuesQuery { + return bun.NewValuesQuery(db.DB, model).Conn(db) +} + +func (db *WrappedDB) NewMerge() *bun.MergeQuery { + return bun.NewMergeQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewSelect() *bun.SelectQuery { + return bun.NewSelectQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewInsert() *bun.InsertQuery { + return bun.NewInsertQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewUpdate() *bun.UpdateQuery { + return bun.NewUpdateQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewDelete() *bun.DeleteQuery { + return bun.NewDeleteQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewRaw(query string, args ...interface{}) *bun.RawQuery { + return bun.NewRawQuery(db.DB, query, args...).Conn(db) +} + +func (db *WrappedDB) NewCreateTable() *bun.CreateTableQuery { + return bun.NewCreateTableQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewDropTable() *bun.DropTableQuery { + return bun.NewDropTableQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewCreateIndex() *bun.CreateIndexQuery { + return bun.NewCreateIndexQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewDropIndex() *bun.DropIndexQuery { + return bun.NewDropIndexQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewTruncateTable() *bun.TruncateTableQuery { + return bun.NewTruncateTableQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewAddColumn() *bun.AddColumnQuery { + return bun.NewAddColumnQuery(db.DB).Conn(db) +} + +func (db *WrappedDB) NewDropColumn() *bun.DropColumnQuery { + return bun.NewDropColumnQuery(db.DB).Conn(db) +} + +// ProcessError processes an error to replace any known values with our own error types, +// making it easier to catch specific situations (e.g. no rows, already exists, etc) +func (db *WrappedDB) ProcessError(err error) error { + if err == nil { + return nil + } + return db.errHook(err) +} + +// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors +func (db *WrappedDB) Exists(ctx context.Context, query *bun.SelectQuery) (bool, error) { + exists, err := query.Exists(ctx) + switch err { + case nil: + return exists, nil + case sql.ErrNoRows: + return false, nil + default: + return false, err + } +} + +// NotExists is the functional opposite of conn.Exists() +func (db *WrappedDB) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, error) { + exists, err := db.Exists(ctx, query) + return !exists, err +} + +// retryOnBusy will retry given function on returned 'errBusy'. +func retryOnBusy(ctx context.Context, fn func() error) error { + var backoff time.Duration + + for i := 0; ; i++ { + // Perform func. + err := fn() + + if err != errBusy { + // May be nil, or may be + // some other error, either + // way return here. + return err + } + + // backoff according to a multiplier of 2ms * 2^2n, + // up to a maximum possible backoff time of 5 minutes. + // + // this works out as the following: + // 4ms + // 16ms + // 64ms + // 256ms + // 1.024s + // 4.096s + // 16.384s + // 1m5.536s + // 4m22.144s + backoff = 2 * time.Millisecond * (1 << (2*i + 1)) + if backoff >= 5*time.Minute { + break + } + + select { + // Context cancelled. + case <-ctx.Done(): + + // Backoff for some time. + case <-time.After(backoff): + } + } + + return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) +} diff --git a/internal/db/domain.go b/internal/db/domain.go index d859752af..740ccefe6 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -27,29 +27,29 @@ import ( // Domain contains DB functions related to domains and domain blocks. type Domain interface { // CreateDomainBlock puts the given instance-level domain block into the database. - CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error + CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error // GetDomainBlock returns one instance-level domain block with the given domain, if it exists. - GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error) + GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) // GetDomainBlockByID returns one instance-level domain block with the given id, if it exists. - GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, Error) + GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, error) // GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance. GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) // DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists. - DeleteDomainBlock(ctx context.Context, domain string) Error + DeleteDomainBlock(ctx context.Context, domain string) error // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). - IsDomainBlocked(ctx context.Context, domain string) (bool, Error) + IsDomainBlocked(ctx context.Context, domain string) (bool, error) // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found. - AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error) + AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`). - IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error) + IsURIBlocked(ctx context.Context, uri *url.URL) (bool, error) // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found. - AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error) + AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, error) } diff --git a/internal/db/emoji.go b/internal/db/emoji.go index 67d7f7232..c4dabd1aa 100644 --- a/internal/db/emoji.go +++ b/internal/db/emoji.go @@ -31,16 +31,16 @@ const EmojiAllDomains string = "all" // Emoji contains functions for getting emoji in the database. type Emoji interface { // PutEmoji puts one emoji in the database. - PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) Error + PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error // UpdateEmoji updates the given columns of one emoji. // If no columns are specified, every column is updated. UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) error // DeleteEmojiByID deletes one emoji by its database ID. - DeleteEmojiByID(ctx context.Context, id string) Error + DeleteEmojiByID(ctx context.Context, id string) error // GetEmojisByIDs gets emojis for the given IDs. - GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, Error) + GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, error) // GetUseableEmojis gets all emojis which are useable by accounts on this instance. - GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, Error) + GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, error) // GetEmojis fetches all emojis with IDs less than 'maxID', up to a maximum of 'limit' emojis. GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) @@ -54,22 +54,22 @@ type Emoji interface { // GetEmojisBy gets emojis based on given parameters. Useful for admin actions. GetEmojisBy(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, error) // GetEmojiByID gets a specific emoji by its database ID. - GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, Error) + GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, error) // GetEmojiByShortcodeDomain gets an emoji based on its shortcode and domain. // For local emoji, domain should be an empty string. - GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, Error) + GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) // GetEmojiByURI returns one emoji based on its ActivityPub URI. - GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, Error) + GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, error) // GetEmojiByStaticURL gets an emoji using the URL of the static version of the emoji image. - GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, Error) + GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, error) // PutEmojiCategory puts one new emoji category in the database. - PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) Error + PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error // GetEmojiCategoriesByIDs gets emoji categories for given IDs. - GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, Error) + GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, error) // GetEmojiCategories gets a slice of the names of all existing emoji categories. - GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, Error) + GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, error) // GetEmojiCategory gets one emoji category by its id. - GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, Error) + GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, error) // GetEmojiCategoryByName gets one emoji category by its name. - GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, Error) + GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) } diff --git a/internal/db/error.go b/internal/db/error.go index 99b8c4c61..b8e488297 100644 --- a/internal/db/error.go +++ b/internal/db/error.go @@ -17,18 +17,20 @@ package db -import "fmt" - -// Error denotes a database error. -type Error error +import ( + "database/sql" + "errors" +) var ( - // ErrNoEntries is returned when a caller expected an entry for a query, but none was found. - ErrNoEntries Error = fmt.Errorf("no entries") - // ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found. - ErrMultipleEntries Error = fmt.Errorf("multiple entries") + // ErrNoEntries is a direct ptr to sql.ErrNoRows since that is returned regardless + // of DB dialect. It is returned when no rows (entries) can be found for a query. + ErrNoEntries = sql.ErrNoRows + // ErrAlreadyExists is returned when a conflict was encountered in the db when doing an insert. - ErrAlreadyExists Error = fmt.Errorf("already exists") - // ErrUnknown denotes an unknown database error. - ErrUnknown Error = fmt.Errorf("unknown error") + ErrAlreadyExists = errors.New("already exists") + + // ErrBusyTimeout is returned if the database connection indicates the connection is too busy + // to complete the supplied query. This is generally intended to be handled internally by the DB. + ErrBusyTimeout = errors.New("busy timeout") ) diff --git a/internal/db/instance.go b/internal/db/instance.go index ab40c7a82..408522b65 100644 --- a/internal/db/instance.go +++ b/internal/db/instance.go @@ -26,16 +26,16 @@ import ( // Instance contains functions for instance-level actions (counting instance users etc.). type Instance interface { // CountInstanceUsers returns the number of known accounts registered with the given domain. - CountInstanceUsers(ctx context.Context, domain string) (int, Error) + CountInstanceUsers(ctx context.Context, domain string) (int, error) // CountInstanceStatuses returns the number of known statuses posted from the given domain. - CountInstanceStatuses(ctx context.Context, domain string) (int, Error) + CountInstanceStatuses(ctx context.Context, domain string) (int, error) // CountInstanceDomains returns the number of known instances known that the given domain federates with. - CountInstanceDomains(ctx context.Context, domain string) (int, Error) + CountInstanceDomains(ctx context.Context, domain string) (int, error) // GetInstance returns the instance entry for the given domain, if it exists. - GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error) + GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, error) // GetInstanceByID returns the instance entry corresponding to the given id, if it exists. GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) @@ -47,12 +47,12 @@ type Instance interface { UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID. - GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) + GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, error) // GetInstancePeers returns a slice of instances that the host instance knows about. - GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, Error) + GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, error) // GetInstanceModeratorAddresses returns a slice of email addresses belonging to active // (as in, not suspended) moderators + admins on this instance. - GetInstanceModeratorAddresses(ctx context.Context) ([]string, Error) + GetInstanceModeratorAddresses(ctx context.Context) ([]string, error) } diff --git a/internal/db/media.go b/internal/db/media.go index 5fb18a8fe..66fa258fe 100644 --- a/internal/db/media.go +++ b/internal/db/media.go @@ -27,7 +27,7 @@ import ( // Media contains functions related to creating/getting/removing media attachments. type Media interface { // GetAttachmentByID gets a single attachment by its ID. - GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error) + GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, error) // GetAttachmentsByIDs fetches a list of media attachments for given IDs. GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) @@ -49,25 +49,25 @@ type Media interface { // GetCachedAttachmentsOlderThan gets limit n remote attachments (including avatars and headers) older than // the given time. These will be returned in order of attachment.created_at descending (i.e. newest to oldest). - GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, Error) + GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error) // CountRemoteOlderThan is like GetRemoteOlderThan, except instead of getting limit n attachments, // it just counts how many remote attachments in the database (including avatars and headers) meet // the olderThan criteria. - CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, Error) + CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, error) // GetAvatarsAndHeaders fetches limit n avatars and headers with an id < maxID. These headers // and avis may be in use or not; the caller should check this if it's important. - GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, Error) + GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) // GetLocalUnattachedOlderThan fetches limit n local media attachments (including avatars and headers), older than // the given time, which aren't header or avatars, and aren't attached to a status. In other words, attachments which were // uploaded but never used for whatever reason, or attachments that were attached to a status which was subsequently deleted. // // These will be returned in order of attachment.created_at descending (newest to oldest in other words). - GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, Error) + GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error) // CountLocalUnattachedOlderThan is like GetLocalUnattachedOlderThan, except instead of getting limit n attachments, // it just counts how many local attachments in the database meet the olderThan criteria. - CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, Error) + CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, error) } diff --git a/internal/db/mention.go b/internal/db/mention.go index 348f946a2..d4125031e 100644 --- a/internal/db/mention.go +++ b/internal/db/mention.go @@ -26,10 +26,10 @@ import ( // Mention contains functions for getting/creating mentions in the database. type Mention interface { // GetMention gets a single mention by ID - GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error) + GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) // GetMentions gets multiple mentions. - GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error) + GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) // PutMention will insert the given mention into the database. PutMention(ctx context.Context, mention *gtsmodel.Mention) error diff --git a/internal/db/notification.go b/internal/db/notification.go index c17cf3d93..5f8766252 100644 --- a/internal/db/notification.go +++ b/internal/db/notification.go @@ -28,21 +28,21 @@ type Notification interface { // GetNotifications returns a slice of notifications that pertain to the given accountID. // // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). - GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, excludeTypes []string) ([]*gtsmodel.Notification, Error) + GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, excludeTypes []string) ([]*gtsmodel.Notification, error) // GetNotification returns one notification according to its id. - GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error) + GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) // GetNotification gets one notification according to the provided parameters, if it exists. // Since not all notifications are about a status, statusID can be an empty string. - GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, Error) + GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, error) // PutNotification will insert the given notification into the database. PutNotification(ctx context.Context, notif *gtsmodel.Notification) error // DeleteNotificationByID deletes one notification according to its id, // and removes that notification from the in-memory cache. - DeleteNotificationByID(ctx context.Context, id string) Error + DeleteNotificationByID(ctx context.Context, id string) error // DeleteNotifications mass deletes notifications targeting targetAccountID // and/or originating from originAccountID. @@ -57,10 +57,10 @@ type Notification interface { // originate from originAccountID will be deleted. // // At least one parameter must not be an empty string. - DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error + DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error // DeleteNotificationsForStatus deletes all notifications that relate to // the given statusID. This function is useful when a status has been deleted, // and so notifications relating to that status must also be deleted. - DeleteNotificationsForStatus(ctx context.Context, statusID string) Error + DeleteNotificationsForStatus(ctx context.Context, statusID string) error } diff --git a/internal/db/relationship.go b/internal/db/relationship.go index e4b81c003..f8866a545 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -26,7 +26,7 @@ import ( // Relationship contains functions for getting or modifying the relationship between two accounts. type Relationship interface { // IsBlocked checks whether source account has a block in place against target. - IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) + IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) // IsEitherBlocked checks whether there is a block in place between either of account1 and account2. IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) @@ -53,7 +53,7 @@ type Relationship interface { DeleteAccountBlocks(ctx context.Context, accountID string) error // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. - GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) + GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) // GetFollowByID fetches follow with given ID from the database. GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) @@ -77,13 +77,13 @@ type Relationship interface { GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. - IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) + IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. - IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) + IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. - IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) + IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) // PutFollow attempts to place the given account follow in the database. PutFollow(ctx context.Context, follow *gtsmodel.Follow) error @@ -125,10 +125,10 @@ type Relationship interface { // In other words, it should create the follow, and delete the existing follow request. // // It will return the newly created follow for further processing. - AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) + AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) // RejectFollowRequest fetches a follow request from the database, and then deletes it. - RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error + RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error // GetAccountFollows returns a slice of follows owned by the given accountID. GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) diff --git a/internal/db/report.go b/internal/db/report.go index 3b58644cf..f39e53140 100644 --- a/internal/db/report.go +++ b/internal/db/report.go @@ -26,18 +26,18 @@ import ( // Report handles getting/creation/deletion/updating of user reports/flags. type Report interface { // GetReportByID gets one report by its db id - GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, Error) + GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, error) // GetReports gets limit n reports using the given parameters. // Parameters that are empty / zero are ignored. - GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, Error) + GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, error) // PutReport puts the given report in the database. - PutReport(ctx context.Context, report *gtsmodel.Report) Error + PutReport(ctx context.Context, report *gtsmodel.Report) error // UpdateReport updates one report by its db id. // The given columns will be updated; if no columns are // provided, then all columns will be updated. // updated_at will also be updated, no need to pass this // as a specific column. - UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, Error) + UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, error) // DeleteReportByID deletes report with the given id. - DeleteReportByID(ctx context.Context, id string) Error + DeleteReportByID(ctx context.Context, id string) error } diff --git a/internal/db/session.go b/internal/db/session.go index de6ef2eb1..944fa4215 100644 --- a/internal/db/session.go +++ b/internal/db/session.go @@ -25,5 +25,5 @@ import ( // Session handles getting/creation of router sessions. type Session interface { - GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error) + GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) } diff --git a/internal/db/status.go b/internal/db/status.go index c0e330260..6f9848f57 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -26,34 +26,34 @@ import ( // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses. type Status interface { // GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs - GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) + GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) // GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs - GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) + GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, error) // GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs - GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error) + GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, error) // PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc). PopulateStatus(ctx context.Context, status *gtsmodel.Status) error // PutStatus stores one status in the database. - PutStatus(ctx context.Context, status *gtsmodel.Status) Error + PutStatus(ctx context.Context, status *gtsmodel.Status) error // UpdateStatus updates one status in the database. - UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) Error + UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) error // DeleteStatusByID deletes one status from the database. - DeleteStatusByID(ctx context.Context, id string) Error + DeleteStatusByID(ctx context.Context, id string) error // CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong - CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error) + CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error) // CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong - CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error) + CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error) // CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong - CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error) + CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error) // GetStatuses gets a slice of statuses corresponding to the given status IDs. GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) @@ -64,26 +64,26 @@ type Status interface { // GetStatusParents gets the parent statuses of a given status. // // If onlyDirect is true, only the immediate parent will be returned. - GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) + GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) // GetStatusChildren gets the child statuses of a given status. // // If onlyDirect is true, only the immediate children will be returned. - GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) + GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) // IsStatusFavedBy checks if a given status has been faved by a given account ID - IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) // IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID - IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) // IsStatusMutedBy checks if a given status has been muted by a given account ID - IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID - IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) // GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error) + GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) } diff --git a/internal/db/statusbookmark.go b/internal/db/statusbookmark.go index e47bfc54d..542e49f17 100644 --- a/internal/db/statusbookmark.go +++ b/internal/db/statusbookmark.go @@ -25,24 +25,24 @@ import ( type StatusBookmark interface { // GetStatusBookmark gets one status bookmark with the given ID. - GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, Error) + GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, error) // GetStatusBookmarkID is a shortcut function for returning just the database ID // of a status bookmark created by the given accountID, targeting the given statusID. - GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, Error) + GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, error) // GetStatusBookmarks retrieves status bookmarks created by the given accountID, // and using the provided parameters. If limit is < 0 then no limit will be set. // // This function is primarily useful for paging through bookmarks in a sort of // timeline view. - GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, Error) + GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, error) // PutStatusBookmark inserts the given statusBookmark into the database. - PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) Error + PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) error // DeleteStatusBookmark deletes one status bookmark with the given ID. - DeleteStatusBookmark(ctx context.Context, id string) Error + DeleteStatusBookmark(ctx context.Context, id string) error // DeleteStatusBookmarks mass deletes status bookmarks targeting targetAccountID // and/or originating from originAccountID and/or bookmarking statusID. @@ -57,10 +57,10 @@ type StatusBookmark interface { // originate from originAccountID will be deleted. // // At least one parameter must not be an empty string. - DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) Error + DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) error // DeleteStatusBookmarksForStatus deletes all status bookmarks that target the // given status ID. This is useful when a status has been deleted, and you need // to clean up after it. - DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) Error + DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) error } diff --git a/internal/db/statusfave.go b/internal/db/statusfave.go index 98ff1d69d..37769ff79 100644 --- a/internal/db/statusfave.go +++ b/internal/db/statusfave.go @@ -26,23 +26,23 @@ import ( type StatusFave interface { // GetStatusFaveByAccountID gets one status fave created by the given // accountID, targeting the given statusID. - GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error) + GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) // GetStatusFave returns one status fave with the given id. - GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error) + GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, error) // GetStatusFaves returns a slice of faves/likes of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error) + GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) // PopulateStatusFave ensures that all sub-models of a fave are populated (account, status, etc). PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error // PutStatusFave inserts the given statusFave into the database. - PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error + PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error // DeleteStatusFave deletes one status fave with the given id. - DeleteStatusFaveByID(ctx context.Context, id string) Error + DeleteStatusFaveByID(ctx context.Context, id string) error // DeleteStatusFaves mass deletes status faves targeting targetAccountID // and/or originating from originAccountID and/or faving statusID. @@ -57,10 +57,10 @@ type StatusFave interface { // originate from originAccountID will be deleted. // // At least one parameter must not be an empty string. - DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) Error + DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error // DeleteStatusFavesForStatus deletes all status faves that target the // given status ID. This is useful when a status has been deleted, and you need // to clean up after it. - DeleteStatusFavesForStatus(ctx context.Context, statusID string) Error + DeleteStatusFavesForStatus(ctx context.Context, statusID string) error } diff --git a/internal/db/timeline.go b/internal/db/timeline.go index 2635bece2..40d5b8015 100644 --- a/internal/db/timeline.go +++ b/internal/db/timeline.go @@ -28,13 +28,13 @@ type Timeline interface { // GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id. // // Statuses should be returned in descending order of when they were created (newest first). - GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) // GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public. // It will use the given filters and try to return as many statuses as possible up to the limit. // // Statuses should be returned in descending order of when they were created (newest first). - GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) // GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved. // It will use the given filters and try to return as many statuses as possible up to the limit. @@ -43,7 +43,7 @@ type Timeline interface { // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created. // // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. - GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) + GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) // GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID. // Statuses should be returned in descending order of when they were created (newest first). diff --git a/internal/db/tombstone.go b/internal/db/tombstone.go index e3db07cec..362b6747d 100644 --- a/internal/db/tombstone.go +++ b/internal/db/tombstone.go @@ -26,14 +26,14 @@ import ( // Tombstone contains functionality for storing + retrieving tombstones for remote AP Activities + Objects. type Tombstone interface { // GetTombstoneByURI attempts to fetch a tombstone by the given URI. - GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, Error) + GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) // TombstoneExistsWithURI returns true if a tombstone with the given URI exists. - TombstoneExistsWithURI(ctx context.Context, uri string) (bool, Error) + TombstoneExistsWithURI(ctx context.Context, uri string) (bool, error) // PutTombstone creates a new tombstone in the database. - PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) Error + PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error // DeleteTombstone deletes a tombstone with the given ID. - DeleteTombstone(ctx context.Context, id string) Error + DeleteTombstone(ctx context.Context, id string) error } diff --git a/internal/db/user.go b/internal/db/user.go index 15165d0be..9df672837 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -26,21 +26,21 @@ import ( // User contains functions related to user getting/setting/creation. type User interface { // GetAllUsers returns all local user accounts, or an error if something goes wrong. - GetAllUsers(ctx context.Context) ([]*gtsmodel.User, Error) + GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) // GetUserByID returns one user with the given ID, or an error if something goes wrong. - GetUserByID(ctx context.Context, id string) (*gtsmodel.User, Error) + GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) // GetUserByAccountID returns one user by its account ID, or an error if something goes wrong. - GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, Error) + GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) // GetUserByID returns one user with the given email address, or an error if something goes wrong. - GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error) + GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error) // GetUserByExternalID returns one user with the given external id, or an error if something goes wrong. - GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, Error) + GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) // GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong. - GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error) + GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error) // PutUser will attempt to place user in the database - PutUser(ctx context.Context, user *gtsmodel.User) Error + PutUser(ctx context.Context, user *gtsmodel.User) error // UpdateUser updates one user by its primary key, updating either only the specified columns, or all of them. - UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) Error + UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) error // DeleteUserByID deletes one user by its ID. - DeleteUserByID(ctx context.Context, userID string) Error + DeleteUserByID(ctx context.Context, userID string) error } diff --git a/internal/federation/dereferencing/account_test.go b/internal/federation/dereferencing/account_test.go index 71028e342..3b6994f08 100644 --- a/internal/federation/dereferencing/account_test.go +++ b/internal/federation/dereferencing/account_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -175,7 +176,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsername() config.GetHost(), ) suite.True(gtserror.Unretrievable(err)) - suite.EqualError(err, "no entries") + suite.EqualError(err, db.ErrNoEntries.Error()) suite.Nil(fetchedAccount) } @@ -189,7 +190,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsernameDom "localhost:8080", ) suite.True(gtserror.Unretrievable(err)) - suite.EqualError(err, "no entries") + suite.EqualError(err, db.ErrNoEntries.Error()) suite.Nil(fetchedAccount) } @@ -202,7 +203,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUserURI() { testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"), ) suite.True(gtserror.Unretrievable(err)) - suite.EqualError(err, "no entries") + suite.EqualError(err, db.ErrNoEntries.Error()) suite.Nil(fetchedAccount) } diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index 65c521113..18bbe1ee9 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -232,7 +232,7 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e return nil, err } - l.Infof("performing request") + l.Info("performing request") // Perform the request. rsp, err = c.do(r) diff --git a/internal/middleware/signaturecheck.go b/internal/middleware/signaturecheck.go index df2ac0300..87c7aac01 100644 --- a/internal/middleware/signaturecheck.go +++ b/internal/middleware/signaturecheck.go @@ -22,7 +22,6 @@ import ( "net/http" "net/url" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -48,7 +47,7 @@ const ( // context for use down the line. // // In case of an error, the request will be aborted with http code 500. -func SignatureCheck(uriBlocked func(context.Context, *url.URL) (bool, db.Error)) func(*gin.Context) { +func SignatureCheck(uriBlocked func(context.Context, *url.URL) (bool, error)) func(*gin.Context) { return func(c *gin.Context) { ctx := c.Request.Context() diff --git a/internal/processing/stream/authorize_test.go b/internal/processing/stream/authorize_test.go index b2c98c5f1..cb91d5b30 100644 --- a/internal/processing/stream/authorize_test.go +++ b/internal/processing/stream/authorize_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" ) type AuthorizeTestSuite struct { @@ -38,7 +39,7 @@ func (suite *AuthorizeTestSuite) TestAuthorize() { suite.Equal(suite.testAccounts["local_account_2"].ID, account2.ID) noAccount, err := suite.streamProcessor.Authorize(context.Background(), "aaaaaaaaaaaaaaaaaaaaa!!") - suite.EqualError(err, "could not load access token: no entries") + suite.EqualError(err, "could not load access token: "+db.ErrNoEntries.Error()) suite.Nil(noAccount) } diff --git a/internal/web/web.go b/internal/web/web.go index c53433730..5c1c4750d 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -62,7 +62,7 @@ const ( type Module struct { processor *processing.Processor eTagCache cache.Cache[string, eTagCacheEntry] - isURIBlocked func(context.Context, *url.URL) (bool, db.Error) + isURIBlocked func(context.Context, *url.URL) (bool, error) } func New(db db.DB, processor *processing.Processor) *Module {