mirror of
1
Fork 0

[bugfix] Narrow search scope for accounts starting with '@'; don't LOWER SQLite text searches (#2435)

This commit is contained in:
tobi 2023-12-10 14:15:41 +01:00 committed by GitHub
parent d60edf7ec6
commit 3f070a442a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 104 additions and 56 deletions

View File

@ -67,7 +67,7 @@ type searchDB struct {
// WHERE (("account"."domain" IS NULL) OR ("account"."domain" != "account"."username")) // WHERE (("account"."domain" IS NULL) OR ("account"."domain" != "account"."username"))
// AND ("account"."id" < 'ZZZZZZZZZZZZZZZZZZZZZZZZZZ') // AND ("account"."id" < 'ZZZZZZZZZZZZZZZZZZZZZZZZZZ')
// AND ("account"."id" IN (SELECT "target_account_id" FROM "follows" WHERE ("account_id" = '016T5Q3SQKBT337DAKVSKNXXW1'))) // AND ("account"."id" IN (SELECT "target_account_id" FROM "follows" WHERE ("account_id" = '016T5Q3SQKBT337DAKVSKNXXW1')))
// AND ((SELECT LOWER("account"."username" || COALESCE("account"."display_name", '') || COALESCE("account"."note", '')) AS "account_text") LIKE '%turtle%' ESCAPE '\') // AND ((SELECT "account"."username" || COALESCE("account"."display_name", '') || COALESCE("account"."note", '') AS "account_text") LIKE '%turtle%' ESCAPE '\')
// ORDER BY "account"."id" DESC LIMIT 10 // ORDER BY "account"."id" DESC LIMIT 10
func (s *searchDB) SearchForAccounts( func (s *searchDB) SearchForAccounts(
ctx context.Context, ctx context.Context,
@ -128,12 +128,20 @@ func (s *searchDB) SearchForAccounts(
) )
} }
// Select account text as subquery. if strings.HasPrefix(query, "@") {
accountTextSubq := s.accountText(following) // Query looks a bit like a username.
// Normalize it and just look for
// usernames that start with query.
query = query[1:]
subQ := s.accountUsername()
q = whereStartsLike(q, subQ, query)
} else {
// Query looks like arbitrary string.
// Search using LIKE for matches of query // Search using LIKE for matches of query
// string within accountText subquery. // string within accountText subquery.
q = whereLike(q, accountTextSubq, query) subQ := s.accountText(following)
q = whereLike(q, subQ, query)
}
if limit > 0 { if limit > 0 {
// Limit amount of accounts returned. // Limit amount of accounts returned.
@ -191,7 +199,15 @@ func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery {
Where("? = ?", bun.Ident("follow.account_id"), accountID) Where("? = ?", bun.Ident("follow.account_id"), accountID)
} }
// statusText returns a subquery that selects a concatenation // accountUsername returns a subquery that just selects
// from account usernames, without concatenation.
func (s *searchDB) accountUsername() *bun.SelectQuery {
return s.db.
NewSelect().
Column("account.username")
}
// accountText returns a subquery that selects a concatenation
// of account username and display name as "account_text". If // of account username and display name as "account_text". If
// `following` is true, then account note will also be included // `following` is true, then account note will also be included
// in the concatenation. // in the concatenation.
@ -226,14 +242,17 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery {
// different number of placeholders depending on // different number of placeholders depending on
// following/not following. COALESCE calls ensure // following/not following. COALESCE calls ensure
// that we're not trying to concatenate null values. // that we're not trying to concatenate null values.
//
// SQLite search is case insensitive.
// Postgres searches get lowercased.
d := s.db.Dialect().Name() d := s.db.Dialect().Name()
switch { switch {
case d == dialect.SQLite && following: case d == dialect.SQLite && following:
query = "LOWER(? || COALESCE(?, ?) || COALESCE(?, ?)) AS ?" query = "? || COALESCE(?, ?) || COALESCE(?, ?) AS ?"
case d == dialect.SQLite && !following: case d == dialect.SQLite && !following:
query = "LOWER(? || COALESCE(?, ?)) AS ?" query = "? || COALESCE(?, ?) AS ?"
case d == dialect.PG && following: case d == dialect.PG && following:
query = "LOWER(CONCAT(?, COALESCE(?, ?), COALESCE(?, ?))) AS ?" query = "LOWER(CONCAT(?, COALESCE(?, ?), COALESCE(?, ?))) AS ?"
@ -255,7 +274,7 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery {
// WHERE ("status"."boost_of_id" IS NULL) // WHERE ("status"."boost_of_id" IS NULL)
// AND (("status"."account_id" = '01F8MH1H7YV1Z7D2C8K2730QBF') OR ("status"."in_reply_to_account_id" = '01F8MH1H7YV1Z7D2C8K2730QBF')) // AND (("status"."account_id" = '01F8MH1H7YV1Z7D2C8K2730QBF') OR ("status"."in_reply_to_account_id" = '01F8MH1H7YV1Z7D2C8K2730QBF'))
// AND ("status"."id" < 'ZZZZZZZZZZZZZZZZZZZZZZZZZZ') // AND ("status"."id" < 'ZZZZZZZZZZZZZZZZZZZZZZZZZZ')
// AND ((SELECT LOWER("status"."content" || COALESCE("status"."content_warning", '')) AS "status_text") LIKE '%hello%' ESCAPE '\') // AND ((SELECT "status"."content" || COALESCE("status"."content_warning", '') AS "status_text") LIKE '%hello%' ESCAPE '\')
// ORDER BY "status"."id" DESC LIMIT 10 // ORDER BY "status"."id" DESC LIMIT 10
func (s *searchDB) SearchForStatuses( func (s *searchDB) SearchForStatuses(
ctx context.Context, ctx context.Context,
@ -366,11 +385,14 @@ func (s *searchDB) statusText() *bun.SelectQuery {
// SQLite and Postgres use different // SQLite and Postgres use different
// syntaxes for concatenation. // syntaxes for concatenation.
//
// SQLite search is case insensitive.
// Postgres searches get lowercased.
switch s.db.Dialect().Name() { switch s.db.Dialect().Name() {
case dialect.SQLite: case dialect.SQLite:
statusText = statusText.ColumnExpr( statusText = statusText.ColumnExpr(
"LOWER(? || COALESCE(?, ?)) AS ?", "? || COALESCE(?, ?) AS ?",
bun.Ident("status.content"), bun.Ident("status.content_warning"), "", bun.Ident("status.content"), bun.Ident("status.content_warning"), "",
bun.Ident("status_text")) bun.Ident("status_text"))

View File

@ -37,6 +37,24 @@ func (suite *SearchTestSuite) TestSearchAccountsTurtleAny() {
suite.Len(accounts, 1) suite.Len(accounts, 1)
} }
func (suite *SearchTestSuite) TestSearchAccounts1HappyWithPrefix() {
testAccount := suite.testAccounts["local_account_1"]
// Query will just look for usernames that start with "1happy".
accounts, err := suite.db.SearchForAccounts(context.Background(), testAccount.ID, "@1happy", "", "", 10, false, 0)
suite.NoError(err)
suite.Len(accounts, 1)
}
func (suite *SearchTestSuite) TestSearchAccounts1HappyNoPrefix() {
testAccount := suite.testAccounts["local_account_1"]
// Query will do the full coalesce.
accounts, err := suite.db.SearchForAccounts(context.Background(), testAccount.ID, "1happy", "", "", 10, false, 0)
suite.NoError(err)
suite.Len(accounts, 1)
}
func (suite *SearchTestSuite) TestSearchAccountsTurtleFollowing() { func (suite *SearchTestSuite) TestSearchAccountsTurtleFollowing() {
testAccount := suite.testAccounts["local_account_1"] testAccount := suite.testAccounts["local_account_1"]

View File

@ -74,13 +74,6 @@ func (p *Processor) Accounts(
return nil, gtserror.NewErrorBadRequest(err, err.Error()) return nil, gtserror.NewErrorBadRequest(err, err.Error())
} }
// Be nice and normalize query by prepending '@'.
// This will make it easier for accountsByNamestring
// to pick this up as a valid namestring.
if query[0] != '@' {
query = "@" + query
}
log. log.
WithContext(ctx). WithContext(ctx).
WithFields(kv.Fields{ WithFields(kv.Fields{
@ -107,9 +100,7 @@ func (p *Processor) Accounts(
// See if we have something that looks like a namestring. // See if we have something that looks like a namestring.
username, domain, err := util.ExtractNamestringParts(query) username, domain, err := util.ExtractNamestringParts(query)
if err != nil { if err == nil {
log.Warnf(ctx, "couldn't parse '%s' as namestring: %v", query, err)
} else {
if domain != "" { if domain != "" {
// Search was an exact namestring; // Search was an exact namestring;
// we can safely assume caller is // we can safely assume caller is
@ -121,7 +112,7 @@ func (p *Processor) Accounts(
// Get all accounts we can find // Get all accounts we can find
// that match the provided query. // that match the provided query.
if err := p.accountsByNamestring( if err := p.accountsByUsernameDomain(
ctx, ctx,
requestingAccount, requestingAccount,
id.Highest, id.Highest,
@ -137,6 +128,23 @@ func (p *Processor) Accounts(
err = gtserror.Newf("error searching by namestring: %w", err) err = gtserror.Newf("error searching by namestring: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
} else {
// Query Doesn't look like a
// namestring, use text search.
if err := p.accountsByText(
ctx,
requestingAccount.ID,
id.Highest,
id.Lowest,
limit,
offset,
query,
following,
appendAccount,
); err != nil && !errors.Is(err, db.ErrNoEntries) {
err = gtserror.Newf("error searching by text: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
} }
// Return whatever we got (if anything). // Return whatever we got (if anything).

View File

@ -165,13 +165,15 @@ func (p *Processor) Get(
// We managed to parse query as a namestring. // We managed to parse query as a namestring.
// If domain was set, this is a very specific // If domain was set, this is a very specific
// search for a particular account, so show // search for a particular account, so show
// that account to the caller even if they // that account to the caller even if it's an
// have it blocked. They might be looking // instance account and/or even if they have
// for it to unblock it again! // it blocked. They might be looking for it
// to unblock it again!
domainSet := (domain != "") domainSet := (domain != "")
includeInstanceAccounts = domainSet
includeBlockedAccounts = domainSet includeBlockedAccounts = domainSet
err = p.accountsByNamestring( err = p.accountsByUsernameDomain(
ctx, ctx,
account, account,
maxID, maxID,
@ -189,13 +191,11 @@ func (p *Processor) Get(
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// If domain was set, we know this is // Namestrings are a pretty unique format, so
// a full namestring, and not a url or // it's very unlikely that the caller was
// just a username, so we should stop // searching for anything except an account.
// looking now and just return what we // As such, return early without falling
// have (if anything). Otherwise we'll // through to broader search.
// let the search keep going.
if domainSet {
return p.packageSearchResult( return p.packageSearchResult(
ctx, ctx,
account, account,
@ -208,7 +208,6 @@ func (p *Processor) Get(
) )
} }
} }
}
// Check if we're searching by a known URI scheme. // Check if we're searching by a known URI scheme.
// (This might just be a weirdly-parsed URI, // (This might just be a weirdly-parsed URI,
@ -331,12 +330,12 @@ func (p *Processor) Get(
) )
} }
// accountsByNamestring searches for accounts using the // accountsByUsernameDomain searches for accounts using
// provided username and domain. If domain is not set, // the provided username and domain. If domain is not set,
// it may return more than one result by doing a text // it may return more than one result by doing a text
// search in the database for accounts matching the query. // search in the database for accounts matching the query.
// Otherwise, it tries to return an exact match. // Otherwise, it tries to return an exact match.
func (p *Processor) accountsByNamestring( func (p *Processor) accountsByUsernameDomain(
ctx context.Context, ctx context.Context,
requestingAccount *gtsmodel.Account, requestingAccount *gtsmodel.Account,
maxID string, maxID string,
@ -350,10 +349,10 @@ func (p *Processor) accountsByNamestring(
appendAccount func(*gtsmodel.Account), appendAccount func(*gtsmodel.Account),
) error { ) error {
if domain == "" { if domain == "" {
// No error, but no domain set. That means the query // No domain set. That means the query looked
// looked like '@someone' which is not an exact search. // like '@someone' which is not an exact search,
// Try to search for any accounts that match the query // but is still a username search. Look for any
// string, and let the caller know they should stop. // usernames that start with the query string.
return p.accountsByText( return p.accountsByText(
ctx, ctx,
requestingAccount.ID, requestingAccount.ID,
@ -361,15 +360,16 @@ func (p *Processor) accountsByNamestring(
minID, minID,
limit, limit,
offset, offset,
// OK to assume username is set now. Use // Add @ prefix back in to indicate
// it instead of query to omit leading '@'. // to search function that we want
username, // an account by its username.
"@"+username,
following, following,
appendAccount, appendAccount,
) )
} }
// No error, and domain and username were both set. // Domain and username were both set.
// Caller is likely trying to search for an exact // Caller is likely trying to search for an exact
// match, from either a remote instance or local. // match, from either a remote instance or local.
foundAccount, err := p.accountByUsernameDomain( foundAccount, err := p.accountByUsernameDomain(