mirror of
1
Fork 0

[bugfix] Fix potential dereference of accounts on own instance (#757)

* add GetAccountByUsernameDomain

* simplify search

* add escape to not deref accounts on own domain

* check if local + we have account by ap uri
This commit is contained in:
tobi 2022-08-20 22:47:19 +02:00 committed by GitHub
parent 2ca234f42e
commit 570fa7c359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 243 additions and 92 deletions

View File

@ -37,6 +37,7 @@ func NewAccountCache() *AccountCache {
RegisterLookups: func(lm *cache.LookupMap[string, string]) { RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("uri") lm.RegisterLookup("uri")
lm.RegisterLookup("url") lm.RegisterLookup("url")
lm.RegisterLookup("usernamedomain")
}, },
AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) { AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
@ -46,6 +47,7 @@ func NewAccountCache() *AccountCache {
if url := acc.URL; url != "" { if url := acc.URL; url != "" {
lm.Set("url", url, acc.ID) lm.Set("url", url, acc.ID)
} }
lm.Set("usernamedomain", usernameDomainKey(acc.Username, acc.Domain), acc.ID)
}, },
DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) { DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
@ -55,6 +57,7 @@ func NewAccountCache() *AccountCache {
if url := acc.URL; url != "" { if url := acc.URL; url != "" {
lm.Delete("url", url) lm.Delete("url", url)
} }
lm.Delete("usernamedomain", usernameDomainKey(acc.Username, acc.Domain))
}, },
}) })
c.cache.SetTTL(time.Minute*5, false) c.cache.SetTTL(time.Minute*5, false)
@ -77,6 +80,10 @@ func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) {
return c.cache.GetBy("uri", uri) return c.cache.GetBy("uri", uri)
} }
func (c *AccountCache) GetByUsernameDomain(username string, domain string) (*gtsmodel.Account, bool) {
return c.cache.GetBy("usernamedomain", usernameDomainKey(username, domain))
}
// Put places a account in the cache, ensuring that the object place is a copy for thread-safety // Put places a account in the cache, ensuring that the object place is a copy for thread-safety
func (c *AccountCache) Put(account *gtsmodel.Account) { func (c *AccountCache) Put(account *gtsmodel.Account) {
if account == nil || account.ID == "" { if account == nil || account.ID == "" {
@ -135,3 +142,11 @@ func copyAccount(account *gtsmodel.Account) *gtsmodel.Account {
SuspensionOrigin: account.SuspensionOrigin, SuspensionOrigin: account.SuspensionOrigin,
} }
} }
func usernameDomainKey(username string, domain string) string {
u := "@" + username
if domain != "" {
return u + "@" + domain
}
return u
}

View File

@ -69,6 +69,10 @@ func (suite *AccountCacheTestSuite) TestAccountCache() {
if account.URL != "" && !ok && !accountIs(account, check) { if account.URL != "" && !ok && !accountIs(account, check) {
suite.Fail("Failed to fetch expected account with URL: %s", account.URL) suite.Fail("Failed to fetch expected account with URL: %s", account.URL)
} }
check, ok = suite.cache.GetByUsernameDomain(account.Username, account.Domain)
if !ok && !accountIs(account, check) {
suite.Fail("Failed to fetch expected account with username/domain: %s/%s", account.Username, account.Domain)
}
} }
} }

View File

@ -36,6 +36,9 @@ type Account interface {
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong. // GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong.
GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error)
// UpdateAccount updates one account by ID. // UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)

View File

@ -84,6 +84,26 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
) )
} }
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
func() (*gtsmodel.Account, bool) {
return a.cache.GetByUsernameDomain(username, domain)
},
func(account *gtsmodel.Account) error {
q := a.newAccountQ(account).Where("account.username = ?", username)
if domain != "" {
q = q.Where("account.domain = ?", domain)
} else {
q = q.Where("account.domain IS NULL")
}
return q.Scan(ctx)
},
)
}
func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
// Attempt to fetch cached account // Attempt to fetch cached account
account, cached := cacheGet() account, cached := cacheGet()

View File

@ -58,6 +58,18 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL) suite.NotEmpty(account.HeaderMediaAttachment.URL)
} }
func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
testAccount1 := suite.testAccounts["local_account_1"]
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
suite.NoError(err)
suite.NotNil(account1)
testAccount2 := suite.testAccounts["remote_account_1"]
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
suite.NoError(err)
suite.NotNil(account2)
}
func (suite *AccountTestSuite) TestUpdateAccount() { func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount := suite.testAccounts["local_account_1"] testAccount := suite.testAccounts["local_account_1"]

View File

@ -32,6 +32,7 @@ import (
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
@ -78,7 +79,10 @@ type GetRemoteAccountParams struct {
// GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account, // GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account,
// puts or updates it in the database (if necessary), and returns it to a caller. // puts or updates it in the database (if necessary), and returns it to a caller.
func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (remoteAccount *gtsmodel.Account, err error) { //
// If a local account is passed into this function for whatever reason (hey, it happens!), then it
// will be returned from the database without making any remote calls.
func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (foundAccount *gtsmodel.Account, err error) {
/* /*
In this function we want to retrieve a gtsmodel representation of a remote account, with its proper In this function we want to retrieve a gtsmodel representation of a remote account, with its proper
accountDomain set, while making as few calls to remote instances as possible to save time and bandwidth. accountDomain set, while making as few calls to remote instances as possible to save time and bandwidth.
@ -99,23 +103,40 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
from that. from that.
*/ */
// first check if we can retrieve the account locally just with what we've been given skipResolve := params.SkipResolve
// this first step checks if we have the
// account in the database somewhere already
switch { switch {
case params.RemoteAccountID != nil: case params.RemoteAccountID != nil:
// try with uri uri := params.RemoteAccountID
if a, dbErr := d.db.GetAccountByURI(ctx, params.RemoteAccountID.String()); dbErr == nil { host := uri.Host
remoteAccount = a if host == config.GetHost() || host == config.GetAccountDomain() {
// this is actually a local account,
// make sure we don't try to resolve
skipResolve = true
}
if a, dbErr := d.db.GetAccountByURI(ctx, uri.String()); dbErr == nil {
foundAccount = a
} else if dbErr != db.ErrNoEntries { } else if dbErr != db.ErrNoEntries {
err = fmt.Errorf("GetRemoteAccount: database error looking for account %s: %s", params.RemoteAccountID, err) err = fmt.Errorf("GetRemoteAccount: database error looking for account with uri %s: %s", uri, err)
}
case params.RemoteAccountUsername != "" && (params.RemoteAccountHost == "" || params.RemoteAccountHost == config.GetHost() || params.RemoteAccountHost == config.GetAccountDomain()):
// either no domain is provided or this seems
// to be a local account, so don't resolve
skipResolve = true
if a, dbErr := d.db.GetLocalAccountByUsername(ctx, params.RemoteAccountUsername); dbErr == nil {
foundAccount = a
} else if dbErr != db.ErrNoEntries {
err = fmt.Errorf("GetRemoteAccount: database error looking for local account with username %s: %s", params.RemoteAccountUsername, err)
} }
case params.RemoteAccountUsername != "" && params.RemoteAccountHost != "": case params.RemoteAccountUsername != "" && params.RemoteAccountHost != "":
// try with username/host if a, dbErr := d.db.GetAccountByUsernameDomain(ctx, params.RemoteAccountUsername, params.RemoteAccountHost); dbErr == nil {
a := &gtsmodel.Account{} foundAccount = a
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: params.RemoteAccountHost}}
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
remoteAccount = a
} else if dbErr != db.ErrNoEntries { } else if dbErr != db.ErrNoEntries {
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err) err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and domain %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
} }
default: default:
err = errors.New("GetRemoteAccount: no identifying parameters were set so we cannot get account") err = errors.New("GetRemoteAccount: no identifying parameters were set so we cannot get account")
@ -125,10 +146,11 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
return return
} }
if params.SkipResolve { if skipResolve {
// if we can't resolve, return already since there's nothing more we can do // if we can't resolve, return already
if remoteAccount == nil { // since there's nothing more we can do
err = errors.New("GetRemoteAccount: error retrieving account with skipResolve set true") if foundAccount == nil {
err = errors.New("GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
} }
return return
} }
@ -141,8 +163,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// ... but we still need the username so we can do a finger for the accountDomain // ... but we still need the username so we can do a finger for the accountDomain
// check if we had the account stored already and got it earlier // check if we had the account stored already and got it earlier
if remoteAccount != nil { if foundAccount != nil {
params.RemoteAccountUsername = remoteAccount.Username params.RemoteAccountUsername = foundAccount.Username
} else { } else {
// if we didn't already have it, we have dereference it from remote and just... // if we didn't already have it, we have dereference it from remote and just...
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID) accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
@ -167,8 +189,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// already about what the account domain might be; this var will be overwritten later if necessary // already about what the account domain might be; this var will be overwritten later if necessary
var accountDomain string var accountDomain string
switch { switch {
case remoteAccount != nil: case foundAccount != nil:
accountDomain = remoteAccount.Domain accountDomain = foundAccount.Domain
case params.RemoteAccountID != nil: case params.RemoteAccountID != nil:
accountDomain = params.RemoteAccountID.Host accountDomain = params.RemoteAccountID.Host
default: default:
@ -178,7 +200,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// to save on remote calls: only webfinger if we don't have a remoteAccount yet, or if we haven't // to save on remote calls: only webfinger if we don't have a remoteAccount yet, or if we haven't
// fingered the remote account for at least 2 days; don't finger instance accounts // fingered the remote account for at least 2 days; don't finger instance accounts
var fingered time.Time var fingered time.Time
if remoteAccount == nil || (remoteAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(remoteAccount)) { if foundAccount == nil || (foundAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(foundAccount)) {
accountDomain, params.RemoteAccountID, err = d.fingerRemoteAccount(ctx, params.RequestingUsername, params.RemoteAccountUsername, params.RemoteAccountHost) accountDomain, params.RemoteAccountID, err = d.fingerRemoteAccount(ctx, params.RequestingUsername, params.RemoteAccountUsername, params.RemoteAccountHost)
if err != nil { if err != nil {
err = fmt.Errorf("GetRemoteAccount: error while fingering: %s", err) err = fmt.Errorf("GetRemoteAccount: error while fingering: %s", err)
@ -187,14 +209,14 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
fingered = time.Now() fingered = time.Now()
} }
if !fingered.IsZero() && remoteAccount == nil { if !fingered.IsZero() && foundAccount == nil {
// if we just fingered and now have a discovered account domain but still no account, // if we just fingered and now have a discovered account domain but still no account,
// we should do a final lookup in the database with the discovered username + accountDomain // we should do a final lookup in the database with the discovered username + accountDomain
// to make absolutely sure we don't already have this account // to make absolutely sure we don't already have this account
a := &gtsmodel.Account{} a := &gtsmodel.Account{}
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: accountDomain}} where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: accountDomain}}
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil { if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
remoteAccount = a foundAccount = a
} else if dbErr != db.ErrNoEntries { } else if dbErr != db.ErrNoEntries {
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err) err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
return return
@ -203,7 +225,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// we may also have some extra information already, like the account we had in the db, or the // we may also have some extra information already, like the account we had in the db, or the
// accountable representation that we dereferenced from remote // accountable representation that we dereferenced from remote
if remoteAccount == nil { if foundAccount == nil {
// we still don't have the account, so deference it if we didn't earlier // we still don't have the account, so deference it if we didn't earlier
if accountable == nil { if accountable == nil {
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID) accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
@ -214,7 +236,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
} }
// then convert // then convert
remoteAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false) foundAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false)
if err != nil { if err != nil {
err = fmt.Errorf("GetRemoteAccount: error converting accountable to account: %s", err) err = fmt.Errorf("GetRemoteAccount: error converting accountable to account: %s", err)
return return
@ -227,18 +249,18 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
err = fmt.Errorf("GetRemoteAccount: error generating new id for account: %s", err) err = fmt.Errorf("GetRemoteAccount: error generating new id for account: %s", err)
return return
} }
remoteAccount.ID = ulid foundAccount.ID = ulid
_, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking) _, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)
if err != nil { if err != nil {
err = fmt.Errorf("GetRemoteAccount: error populating further account fields: %s", err) err = fmt.Errorf("GetRemoteAccount: error populating further account fields: %s", err)
return return
} }
remoteAccount.LastWebfingeredAt = fingered foundAccount.LastWebfingeredAt = fingered
remoteAccount.UpdatedAt = time.Now() foundAccount.UpdatedAt = time.Now()
err = d.db.Put(ctx, remoteAccount) err = d.db.Put(ctx, foundAccount)
if err != nil { if err != nil {
err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err) err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err)
return return
@ -248,9 +270,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
} }
// we had the account already, but now we know the account domain, so update it if it's different // we had the account already, but now we know the account domain, so update it if it's different
if !strings.EqualFold(remoteAccount.Domain, accountDomain) { if !strings.EqualFold(foundAccount.Domain, accountDomain) {
remoteAccount.Domain = accountDomain foundAccount.Domain = accountDomain
remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount) foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)
if err != nil { if err != nil {
err = fmt.Errorf("GetRemoteAccount: error updating account: %s", err) err = fmt.Errorf("GetRemoteAccount: error updating account: %s", err)
return return
@ -260,7 +282,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// make sure the account fields are populated before returning: // make sure the account fields are populated before returning:
// the caller might want to block until everything is loaded // the caller might want to block until everything is loaded
var fieldsChanged bool var fieldsChanged bool
fieldsChanged, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking) fieldsChanged, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRemoteAccount: error populating remoteAccount fields: %s", err) return nil, fmt.Errorf("GetRemoteAccount: error populating remoteAccount fields: %s", err)
} }
@ -268,12 +290,12 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
var fingeredChanged bool var fingeredChanged bool
if !fingered.IsZero() { if !fingered.IsZero() {
fingeredChanged = true fingeredChanged = true
remoteAccount.LastWebfingeredAt = fingered foundAccount.LastWebfingeredAt = fingered
} }
if fieldsChanged || fingeredChanged { if fieldsChanged || fingeredChanged {
remoteAccount.UpdatedAt = time.Now() foundAccount.UpdatedAt = time.Now()
remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount) foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err) return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err)
} }

View File

@ -21,9 +21,11 @@ package dereferencing_test
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" "github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -42,11 +44,11 @@ func (suite *AccountTestSuite) TestDereferenceGroup() {
}) })
suite.NoError(err) suite.NoError(err)
suite.NotNil(group) suite.NotNil(group)
suite.NotNil(group)
// group values should be set // group values should be set
suite.Equal("https://unknown-instance.com/groups/some_group", group.URI) suite.Equal("https://unknown-instance.com/groups/some_group", group.URI)
suite.Equal("https://unknown-instance.com/@some_group", group.URL) suite.Equal("https://unknown-instance.com/@some_group", group.URL)
suite.WithinDuration(time.Now(), group.LastWebfingeredAt, 5*time.Second)
// group should be in the database // group should be in the database
dbGroup, err := suite.db.GetAccountByURI(context.Background(), group.URI) dbGroup, err := suite.db.GetAccountByURI(context.Background(), group.URI)
@ -65,11 +67,11 @@ func (suite *AccountTestSuite) TestDereferenceService() {
}) })
suite.NoError(err) suite.NoError(err)
suite.NotNil(service) suite.NotNil(service)
suite.NotNil(service)
// service values should be set // service values should be set
suite.Equal("https://owncast.example.org/federation/user/rgh", service.URI) suite.Equal("https://owncast.example.org/federation/user/rgh", service.URI)
suite.Equal("https://owncast.example.org/federation/user/rgh", service.URL) suite.Equal("https://owncast.example.org/federation/user/rgh", service.URL)
suite.WithinDuration(time.Now(), service.LastWebfingeredAt, 5*time.Second)
// service should be in the database // service should be in the database
dbService, err := suite.db.GetAccountByURI(context.Background(), service.URI) dbService, err := suite.db.GetAccountByURI(context.Background(), service.URI)
@ -79,6 +81,102 @@ func (suite *AccountTestSuite) TestDereferenceService() {
suite.Equal("example.org", dbService.Domain) suite.Equal("example.org", dbService.Domain)
} }
/*
We shouldn't try webfingering or making http calls to dereference local accounts
that might be passed into GetRemoteAccount for whatever reason, so these tests are
here to make sure that such cases are (basically) short-circuit evaluated and given
back as-is without trying to make any calls to one's own instance.
*/
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsRemoteURL() {
fetchingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_2"]
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
RequestingUsername: fetchingAccount.Username,
RemoteAccountID: testrig.URLMustParse(targetAccount.URI),
})
suite.NoError(err)
suite.NotNil(fetchedAccount)
suite.Empty(fetchedAccount.Domain)
}
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsername() {
fetchingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_2"]
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
RequestingUsername: fetchingAccount.Username,
RemoteAccountUsername: targetAccount.Username,
})
suite.NoError(err)
suite.NotNil(fetchedAccount)
suite.Empty(fetchedAccount.Domain)
}
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomain() {
fetchingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_2"]
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
RequestingUsername: fetchingAccount.Username,
RemoteAccountUsername: targetAccount.Username,
RemoteAccountHost: config.GetHost(),
})
suite.NoError(err)
suite.NotNil(fetchedAccount)
suite.Empty(fetchedAccount.Domain)
}
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomainAndURL() {
fetchingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_2"]
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
RequestingUsername: fetchingAccount.Username,
RemoteAccountID: testrig.URLMustParse(targetAccount.URI),
RemoteAccountUsername: targetAccount.Username,
RemoteAccountHost: config.GetHost(),
})
suite.NoError(err)
suite.NotNil(fetchedAccount)
suite.Empty(fetchedAccount.Domain)
}
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsername() {
fetchingAccount := suite.testAccounts["local_account_1"]
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
RequestingUsername: fetchingAccount.Username,
RemoteAccountUsername: "thisaccountdoesnotexist",
})
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
suite.Nil(fetchedAccount)
}
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsernameDomain() {
fetchingAccount := suite.testAccounts["local_account_1"]
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
RequestingUsername: fetchingAccount.Username,
RemoteAccountUsername: "thisaccountdoesnotexist",
RemoteAccountHost: "localhost:8080",
})
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
suite.Nil(fetchedAccount)
}
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUserURI() {
fetchingAccount := suite.testAccounts["local_account_1"]
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
RequestingUsername: fetchingAccount.Username,
RemoteAccountID: testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"),
})
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
suite.Nil(fetchedAccount)
}
func TestAccountTestSuite(t *testing.T) { func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite)) suite.Run(t, new(AccountTestSuite))
} }

View File

@ -39,7 +39,6 @@ import (
func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) { func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) {
l := log.WithFields(kv.Fields{ l := log.WithFields(kv.Fields{
{"query", search.Query}, {"query", search.Query},
}...) }...)
@ -62,7 +61,7 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
/* /*
SEARCH BY MENTION SEARCH BY MENTION
check if the query is something like @whatever_username@example.org -- this means it's a remote account check if the query is something like @whatever_username@example.org -- this means it's likely a remote account
*/ */
maybeNamestring := query maybeNamestring := query
if maybeNamestring[0] != '@' { if maybeNamestring[0] != '@' {
@ -135,7 +134,6 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) { func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) {
l := log.WithFields(kv.Fields{ l := log.WithFields(kv.Fields{
{"uri", uri.String()}, {"uri", uri.String()},
{"resolve", resolve}, {"resolve", resolve},
}...) }...)
@ -161,67 +159,46 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u
} }
func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) { func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) {
if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil { // it might be a web url like http://example.org/@user instead
return maybeAccount, nil // of an AP uri like http://example.org/users/user, check first
} else if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil { if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil {
return maybeAccount, nil return maybeAccount, nil
} }
if resolve { if uri.Host == config.GetHost() || uri.Host == config.GetAccountDomain() {
// we don't have it locally so try and dereference it // this is a local account; if we don't have it now then
account, err := p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ // we should just bail instead of trying to get it remote
RequestingUsername: authed.Account.Username, if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil {
RemoteAccountID: uri, return maybeAccount, nil
})
if err != nil {
return nil, fmt.Errorf("searchAccountByURI: error dereferencing account with uri %s: %s", uri.String(), err)
} }
return account, nil return nil, nil
} }
return nil, nil
// we don't have it yet, try to find it remotely
return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
RequestingUsername: authed.Account.Username,
RemoteAccountID: uri,
Blocking: true,
SkipResolve: !resolve,
})
} }
func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, username string, domain string, resolve bool) (*gtsmodel.Account, error) { func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, username string, domain string, resolve bool) (*gtsmodel.Account, error) {
maybeAcct := &gtsmodel.Account{}
var err error
// if it's a local account we can skip a whole bunch of stuff // if it's a local account we can skip a whole bunch of stuff
if domain == config.GetHost() || domain == config.GetAccountDomain() || domain == "" { if domain == config.GetHost() || domain == config.GetAccountDomain() || domain == "" {
maybeAcct, err = p.db.GetLocalAccountByUsername(ctx, username) maybeAcct, err := p.db.GetLocalAccountByUsername(ctx, username)
if err != nil && err != db.ErrNoEntries { if err == nil || err == db.ErrNoEntries {
return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err) return maybeAcct, nil
} }
return maybeAcct, nil return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err)
} }
// it's not a local account so first we'll check if it's in the database already... // we don't have it yet, try to find it remotely
where := []db.Where{ return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
{Key: "username", Value: username, CaseInsensitive: true}, RequestingUsername: authed.Account.Username,
{Key: "domain", Value: domain, CaseInsensitive: true}, RemoteAccountUsername: username,
} RemoteAccountHost: domain,
err = p.db.GetWhere(ctx, where, maybeAcct) Blocking: true,
if err == nil { SkipResolve: !resolve,
// we've got it stored locally already! })
return maybeAcct, nil
}
if err != db.ErrNoEntries {
// if it's not errNoEntries there's been a real database error so bail at this point
return nil, fmt.Errorf("searchAccountByMention: database error: %s", err)
}
// we got a db.ErrNoEntries, so we just don't have the account locally stored -- check if we can dereference it
if resolve {
maybeAcct, err = p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
RequestingUsername: authed.Account.Username,
RemoteAccountUsername: username,
RemoteAccountHost: domain,
})
if err != nil {
return nil, fmt.Errorf("searchAccountByMention: error getting remote account: %s", err)
}
return maybeAcct, nil
}
return nil, nil
} }