From cf20397f261becaf84d4d3e3f6620d1366b34131 Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Thu, 1 Dec 2022 16:06:09 +0100 Subject: [PATCH] [bugfix] Use case-insensitive selects when getting remote accounts by username/domain (#1191) * [bugfix] Case-insensitive account selection * don't lowercase cache key --- internal/db/bundb/account.go | 11 +++++++---- internal/db/bundb/account_test.go | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 712e0c1c7..ea0852d77 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "codeberg.org/gruf/go-cache/v3/result" @@ -108,11 +109,13 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str q := a.newAccountQ(account) if domain != "" { - q = q.Where("? = ?", bun.Ident("account.username"), username) - q = q.Where("? = ?", bun.Ident("account.domain"), domain) + q = q. + Where("LOWER(?) = ?", bun.Ident("account.username"), strings.ToLower(username)). + Where("? = ?", bun.Ident("account.domain"), domain) } else { - q = q.Where("? = ?", bun.Ident("account.username"), username) - q = q.Where("? IS NULL", bun.Ident("account.domain")) + q = q. + Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)). // usernames on our instance are always lowercase + Where("? IS NULL", bun.Ident("account.domain")) } return q.Scan(ctx) diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 50603623f..bf85f14f4 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -22,6 +22,7 @@ import ( "context" "crypto/rand" "crypto/rsa" + "strings" "testing" "time" @@ -84,6 +85,22 @@ func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() { suite.NotNil(account2) } +func (suite *AccountTestSuite) TestGetAccountByUsernameDomainMixedCase() { + testAccount := suite.testAccounts["remote_account_2"] + + account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount.Username, testAccount.Domain) + suite.NoError(err) + suite.NotNil(account1) + + account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToUpper(testAccount.Username), testAccount.Domain) + suite.NoError(err) + suite.NotNil(account2) + + account3, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToLower(testAccount.Username), testAccount.Domain) + suite.NoError(err) + suite.NotNil(account3) +} + func (suite *AccountTestSuite) TestUpdateAccount() { ctx := context.Background()