[bugfix] Fix potential dereference of accounts on own instance (#757)
* add GetAccountByUsernameDomain * simplify search * add escape to not deref accounts on own domain * check if local + we have account by ap uri
This commit is contained in:
parent
2ca234f42e
commit
570fa7c359
|
@ -37,6 +37,7 @@ func NewAccountCache() *AccountCache {
|
||||||
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
|
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
|
||||||
lm.RegisterLookup("uri")
|
lm.RegisterLookup("uri")
|
||||||
lm.RegisterLookup("url")
|
lm.RegisterLookup("url")
|
||||||
|
lm.RegisterLookup("usernamedomain")
|
||||||
},
|
},
|
||||||
|
|
||||||
AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
|
AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
|
||||||
|
@ -46,6 +47,7 @@ func NewAccountCache() *AccountCache {
|
||||||
if url := acc.URL; url != "" {
|
if url := acc.URL; url != "" {
|
||||||
lm.Set("url", url, acc.ID)
|
lm.Set("url", url, acc.ID)
|
||||||
}
|
}
|
||||||
|
lm.Set("usernamedomain", usernameDomainKey(acc.Username, acc.Domain), acc.ID)
|
||||||
},
|
},
|
||||||
|
|
||||||
DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
|
DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
|
||||||
|
@ -55,6 +57,7 @@ func NewAccountCache() *AccountCache {
|
||||||
if url := acc.URL; url != "" {
|
if url := acc.URL; url != "" {
|
||||||
lm.Delete("url", url)
|
lm.Delete("url", url)
|
||||||
}
|
}
|
||||||
|
lm.Delete("usernamedomain", usernameDomainKey(acc.Username, acc.Domain))
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.cache.SetTTL(time.Minute*5, false)
|
c.cache.SetTTL(time.Minute*5, false)
|
||||||
|
@ -77,6 +80,10 @@ func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) {
|
||||||
return c.cache.GetBy("uri", uri)
|
return c.cache.GetBy("uri", uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *AccountCache) GetByUsernameDomain(username string, domain string) (*gtsmodel.Account, bool) {
|
||||||
|
return c.cache.GetBy("usernamedomain", usernameDomainKey(username, domain))
|
||||||
|
}
|
||||||
|
|
||||||
// Put places a account in the cache, ensuring that the object place is a copy for thread-safety
|
// Put places a account in the cache, ensuring that the object place is a copy for thread-safety
|
||||||
func (c *AccountCache) Put(account *gtsmodel.Account) {
|
func (c *AccountCache) Put(account *gtsmodel.Account) {
|
||||||
if account == nil || account.ID == "" {
|
if account == nil || account.ID == "" {
|
||||||
|
@ -135,3 +142,11 @@ func copyAccount(account *gtsmodel.Account) *gtsmodel.Account {
|
||||||
SuspensionOrigin: account.SuspensionOrigin,
|
SuspensionOrigin: account.SuspensionOrigin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func usernameDomainKey(username string, domain string) string {
|
||||||
|
u := "@" + username
|
||||||
|
if domain != "" {
|
||||||
|
return u + "@" + domain
|
||||||
|
}
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
|
@ -69,6 +69,10 @@ func (suite *AccountCacheTestSuite) TestAccountCache() {
|
||||||
if account.URL != "" && !ok && !accountIs(account, check) {
|
if account.URL != "" && !ok && !accountIs(account, check) {
|
||||||
suite.Fail("Failed to fetch expected account with URL: %s", account.URL)
|
suite.Fail("Failed to fetch expected account with URL: %s", account.URL)
|
||||||
}
|
}
|
||||||
|
check, ok = suite.cache.GetByUsernameDomain(account.Username, account.Domain)
|
||||||
|
if !ok && !accountIs(account, check) {
|
||||||
|
suite.Fail("Failed to fetch expected account with username/domain: %s/%s", account.Username, account.Domain)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,9 @@ type Account interface {
|
||||||
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
|
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
|
||||||
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
|
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
|
||||||
|
|
||||||
|
// GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong.
|
||||||
|
GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error)
|
||||||
|
|
||||||
// UpdateAccount updates one account by ID.
|
// UpdateAccount updates one account by ID.
|
||||||
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
|
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
|
||||||
|
|
||||||
|
|
|
@ -84,6 +84,26 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
|
||||||
|
return a.getAccount(
|
||||||
|
ctx,
|
||||||
|
func() (*gtsmodel.Account, bool) {
|
||||||
|
return a.cache.GetByUsernameDomain(username, domain)
|
||||||
|
},
|
||||||
|
func(account *gtsmodel.Account) error {
|
||||||
|
q := a.newAccountQ(account).Where("account.username = ?", username)
|
||||||
|
|
||||||
|
if domain != "" {
|
||||||
|
q = q.Where("account.domain = ?", domain)
|
||||||
|
} else {
|
||||||
|
q = q.Where("account.domain IS NULL")
|
||||||
|
}
|
||||||
|
|
||||||
|
return q.Scan(ctx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
|
func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
|
||||||
// Attempt to fetch cached account
|
// Attempt to fetch cached account
|
||||||
account, cached := cacheGet()
|
account, cached := cacheGet()
|
||||||
|
|
|
@ -58,6 +58,18 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
|
||||||
suite.NotEmpty(account.HeaderMediaAttachment.URL)
|
suite.NotEmpty(account.HeaderMediaAttachment.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
|
||||||
|
testAccount1 := suite.testAccounts["local_account_1"]
|
||||||
|
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(account1)
|
||||||
|
|
||||||
|
testAccount2 := suite.testAccounts["remote_account_1"]
|
||||||
|
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(account2)
|
||||||
|
}
|
||||||
|
|
||||||
func (suite *AccountTestSuite) TestUpdateAccount() {
|
func (suite *AccountTestSuite) TestUpdateAccount() {
|
||||||
testAccount := suite.testAccounts["local_account_1"]
|
testAccount := suite.testAccounts["local_account_1"]
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ import (
|
||||||
"github.com/superseriousbusiness/activity/streams"
|
"github.com/superseriousbusiness/activity/streams"
|
||||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||||
|
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||||
|
@ -78,7 +79,10 @@ type GetRemoteAccountParams struct {
|
||||||
|
|
||||||
// GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account,
|
// GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account,
|
||||||
// puts or updates it in the database (if necessary), and returns it to a caller.
|
// puts or updates it in the database (if necessary), and returns it to a caller.
|
||||||
func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (remoteAccount *gtsmodel.Account, err error) {
|
//
|
||||||
|
// If a local account is passed into this function for whatever reason (hey, it happens!), then it
|
||||||
|
// will be returned from the database without making any remote calls.
|
||||||
|
func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (foundAccount *gtsmodel.Account, err error) {
|
||||||
/*
|
/*
|
||||||
In this function we want to retrieve a gtsmodel representation of a remote account, with its proper
|
In this function we want to retrieve a gtsmodel representation of a remote account, with its proper
|
||||||
accountDomain set, while making as few calls to remote instances as possible to save time and bandwidth.
|
accountDomain set, while making as few calls to remote instances as possible to save time and bandwidth.
|
||||||
|
@ -99,23 +103,40 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
from that.
|
from that.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// first check if we can retrieve the account locally just with what we've been given
|
skipResolve := params.SkipResolve
|
||||||
|
|
||||||
|
// this first step checks if we have the
|
||||||
|
// account in the database somewhere already
|
||||||
switch {
|
switch {
|
||||||
case params.RemoteAccountID != nil:
|
case params.RemoteAccountID != nil:
|
||||||
// try with uri
|
uri := params.RemoteAccountID
|
||||||
if a, dbErr := d.db.GetAccountByURI(ctx, params.RemoteAccountID.String()); dbErr == nil {
|
host := uri.Host
|
||||||
remoteAccount = a
|
if host == config.GetHost() || host == config.GetAccountDomain() {
|
||||||
|
// this is actually a local account,
|
||||||
|
// make sure we don't try to resolve
|
||||||
|
skipResolve = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if a, dbErr := d.db.GetAccountByURI(ctx, uri.String()); dbErr == nil {
|
||||||
|
foundAccount = a
|
||||||
} else if dbErr != db.ErrNoEntries {
|
} else if dbErr != db.ErrNoEntries {
|
||||||
err = fmt.Errorf("GetRemoteAccount: database error looking for account %s: %s", params.RemoteAccountID, err)
|
err = fmt.Errorf("GetRemoteAccount: database error looking for account with uri %s: %s", uri, err)
|
||||||
|
}
|
||||||
|
case params.RemoteAccountUsername != "" && (params.RemoteAccountHost == "" || params.RemoteAccountHost == config.GetHost() || params.RemoteAccountHost == config.GetAccountDomain()):
|
||||||
|
// either no domain is provided or this seems
|
||||||
|
// to be a local account, so don't resolve
|
||||||
|
skipResolve = true
|
||||||
|
|
||||||
|
if a, dbErr := d.db.GetLocalAccountByUsername(ctx, params.RemoteAccountUsername); dbErr == nil {
|
||||||
|
foundAccount = a
|
||||||
|
} else if dbErr != db.ErrNoEntries {
|
||||||
|
err = fmt.Errorf("GetRemoteAccount: database error looking for local account with username %s: %s", params.RemoteAccountUsername, err)
|
||||||
}
|
}
|
||||||
case params.RemoteAccountUsername != "" && params.RemoteAccountHost != "":
|
case params.RemoteAccountUsername != "" && params.RemoteAccountHost != "":
|
||||||
// try with username/host
|
if a, dbErr := d.db.GetAccountByUsernameDomain(ctx, params.RemoteAccountUsername, params.RemoteAccountHost); dbErr == nil {
|
||||||
a := >smodel.Account{}
|
foundAccount = a
|
||||||
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: params.RemoteAccountHost}}
|
|
||||||
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
|
|
||||||
remoteAccount = a
|
|
||||||
} else if dbErr != db.ErrNoEntries {
|
} else if dbErr != db.ErrNoEntries {
|
||||||
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
|
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and domain %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
err = errors.New("GetRemoteAccount: no identifying parameters were set so we cannot get account")
|
err = errors.New("GetRemoteAccount: no identifying parameters were set so we cannot get account")
|
||||||
|
@ -125,10 +146,11 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.SkipResolve {
|
if skipResolve {
|
||||||
// if we can't resolve, return already since there's nothing more we can do
|
// if we can't resolve, return already
|
||||||
if remoteAccount == nil {
|
// since there's nothing more we can do
|
||||||
err = errors.New("GetRemoteAccount: error retrieving account with skipResolve set true")
|
if foundAccount == nil {
|
||||||
|
err = errors.New("GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -141,8 +163,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
// ... but we still need the username so we can do a finger for the accountDomain
|
// ... but we still need the username so we can do a finger for the accountDomain
|
||||||
|
|
||||||
// check if we had the account stored already and got it earlier
|
// check if we had the account stored already and got it earlier
|
||||||
if remoteAccount != nil {
|
if foundAccount != nil {
|
||||||
params.RemoteAccountUsername = remoteAccount.Username
|
params.RemoteAccountUsername = foundAccount.Username
|
||||||
} else {
|
} else {
|
||||||
// if we didn't already have it, we have dereference it from remote and just...
|
// if we didn't already have it, we have dereference it from remote and just...
|
||||||
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
|
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
|
||||||
|
@ -167,8 +189,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
// already about what the account domain might be; this var will be overwritten later if necessary
|
// already about what the account domain might be; this var will be overwritten later if necessary
|
||||||
var accountDomain string
|
var accountDomain string
|
||||||
switch {
|
switch {
|
||||||
case remoteAccount != nil:
|
case foundAccount != nil:
|
||||||
accountDomain = remoteAccount.Domain
|
accountDomain = foundAccount.Domain
|
||||||
case params.RemoteAccountID != nil:
|
case params.RemoteAccountID != nil:
|
||||||
accountDomain = params.RemoteAccountID.Host
|
accountDomain = params.RemoteAccountID.Host
|
||||||
default:
|
default:
|
||||||
|
@ -178,7 +200,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
// to save on remote calls: only webfinger if we don't have a remoteAccount yet, or if we haven't
|
// to save on remote calls: only webfinger if we don't have a remoteAccount yet, or if we haven't
|
||||||
// fingered the remote account for at least 2 days; don't finger instance accounts
|
// fingered the remote account for at least 2 days; don't finger instance accounts
|
||||||
var fingered time.Time
|
var fingered time.Time
|
||||||
if remoteAccount == nil || (remoteAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(remoteAccount)) {
|
if foundAccount == nil || (foundAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(foundAccount)) {
|
||||||
accountDomain, params.RemoteAccountID, err = d.fingerRemoteAccount(ctx, params.RequestingUsername, params.RemoteAccountUsername, params.RemoteAccountHost)
|
accountDomain, params.RemoteAccountID, err = d.fingerRemoteAccount(ctx, params.RequestingUsername, params.RemoteAccountUsername, params.RemoteAccountHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("GetRemoteAccount: error while fingering: %s", err)
|
err = fmt.Errorf("GetRemoteAccount: error while fingering: %s", err)
|
||||||
|
@ -187,14 +209,14 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
fingered = time.Now()
|
fingered = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
if !fingered.IsZero() && remoteAccount == nil {
|
if !fingered.IsZero() && foundAccount == nil {
|
||||||
// if we just fingered and now have a discovered account domain but still no account,
|
// if we just fingered and now have a discovered account domain but still no account,
|
||||||
// we should do a final lookup in the database with the discovered username + accountDomain
|
// we should do a final lookup in the database with the discovered username + accountDomain
|
||||||
// to make absolutely sure we don't already have this account
|
// to make absolutely sure we don't already have this account
|
||||||
a := >smodel.Account{}
|
a := >smodel.Account{}
|
||||||
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: accountDomain}}
|
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: accountDomain}}
|
||||||
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
|
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
|
||||||
remoteAccount = a
|
foundAccount = a
|
||||||
} else if dbErr != db.ErrNoEntries {
|
} else if dbErr != db.ErrNoEntries {
|
||||||
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
|
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
|
||||||
return
|
return
|
||||||
|
@ -203,7 +225,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
|
|
||||||
// we may also have some extra information already, like the account we had in the db, or the
|
// we may also have some extra information already, like the account we had in the db, or the
|
||||||
// accountable representation that we dereferenced from remote
|
// accountable representation that we dereferenced from remote
|
||||||
if remoteAccount == nil {
|
if foundAccount == nil {
|
||||||
// we still don't have the account, so deference it if we didn't earlier
|
// we still don't have the account, so deference it if we didn't earlier
|
||||||
if accountable == nil {
|
if accountable == nil {
|
||||||
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
|
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
|
||||||
|
@ -214,7 +236,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
}
|
}
|
||||||
|
|
||||||
// then convert
|
// then convert
|
||||||
remoteAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false)
|
foundAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("GetRemoteAccount: error converting accountable to account: %s", err)
|
err = fmt.Errorf("GetRemoteAccount: error converting accountable to account: %s", err)
|
||||||
return
|
return
|
||||||
|
@ -227,18 +249,18 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
err = fmt.Errorf("GetRemoteAccount: error generating new id for account: %s", err)
|
err = fmt.Errorf("GetRemoteAccount: error generating new id for account: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
remoteAccount.ID = ulid
|
foundAccount.ID = ulid
|
||||||
|
|
||||||
_, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking)
|
_, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("GetRemoteAccount: error populating further account fields: %s", err)
|
err = fmt.Errorf("GetRemoteAccount: error populating further account fields: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteAccount.LastWebfingeredAt = fingered
|
foundAccount.LastWebfingeredAt = fingered
|
||||||
remoteAccount.UpdatedAt = time.Now()
|
foundAccount.UpdatedAt = time.Now()
|
||||||
|
|
||||||
err = d.db.Put(ctx, remoteAccount)
|
err = d.db.Put(ctx, foundAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err)
|
err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err)
|
||||||
return
|
return
|
||||||
|
@ -248,9 +270,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
}
|
}
|
||||||
|
|
||||||
// we had the account already, but now we know the account domain, so update it if it's different
|
// we had the account already, but now we know the account domain, so update it if it's different
|
||||||
if !strings.EqualFold(remoteAccount.Domain, accountDomain) {
|
if !strings.EqualFold(foundAccount.Domain, accountDomain) {
|
||||||
remoteAccount.Domain = accountDomain
|
foundAccount.Domain = accountDomain
|
||||||
remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount)
|
foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("GetRemoteAccount: error updating account: %s", err)
|
err = fmt.Errorf("GetRemoteAccount: error updating account: %s", err)
|
||||||
return
|
return
|
||||||
|
@ -260,7 +282,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
// make sure the account fields are populated before returning:
|
// make sure the account fields are populated before returning:
|
||||||
// the caller might want to block until everything is loaded
|
// the caller might want to block until everything is loaded
|
||||||
var fieldsChanged bool
|
var fieldsChanged bool
|
||||||
fieldsChanged, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking)
|
fieldsChanged, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetRemoteAccount: error populating remoteAccount fields: %s", err)
|
return nil, fmt.Errorf("GetRemoteAccount: error populating remoteAccount fields: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -268,12 +290,12 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
||||||
var fingeredChanged bool
|
var fingeredChanged bool
|
||||||
if !fingered.IsZero() {
|
if !fingered.IsZero() {
|
||||||
fingeredChanged = true
|
fingeredChanged = true
|
||||||
remoteAccount.LastWebfingeredAt = fingered
|
foundAccount.LastWebfingeredAt = fingered
|
||||||
}
|
}
|
||||||
|
|
||||||
if fieldsChanged || fingeredChanged {
|
if fieldsChanged || fingeredChanged {
|
||||||
remoteAccount.UpdatedAt = time.Now()
|
foundAccount.UpdatedAt = time.Now()
|
||||||
remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount)
|
foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err)
|
return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,9 +21,11 @@ package dereferencing_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||||
|
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
|
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
|
||||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||||
)
|
)
|
||||||
|
@ -42,11 +44,11 @@ func (suite *AccountTestSuite) TestDereferenceGroup() {
|
||||||
})
|
})
|
||||||
suite.NoError(err)
|
suite.NoError(err)
|
||||||
suite.NotNil(group)
|
suite.NotNil(group)
|
||||||
suite.NotNil(group)
|
|
||||||
|
|
||||||
// group values should be set
|
// group values should be set
|
||||||
suite.Equal("https://unknown-instance.com/groups/some_group", group.URI)
|
suite.Equal("https://unknown-instance.com/groups/some_group", group.URI)
|
||||||
suite.Equal("https://unknown-instance.com/@some_group", group.URL)
|
suite.Equal("https://unknown-instance.com/@some_group", group.URL)
|
||||||
|
suite.WithinDuration(time.Now(), group.LastWebfingeredAt, 5*time.Second)
|
||||||
|
|
||||||
// group should be in the database
|
// group should be in the database
|
||||||
dbGroup, err := suite.db.GetAccountByURI(context.Background(), group.URI)
|
dbGroup, err := suite.db.GetAccountByURI(context.Background(), group.URI)
|
||||||
|
@ -65,11 +67,11 @@ func (suite *AccountTestSuite) TestDereferenceService() {
|
||||||
})
|
})
|
||||||
suite.NoError(err)
|
suite.NoError(err)
|
||||||
suite.NotNil(service)
|
suite.NotNil(service)
|
||||||
suite.NotNil(service)
|
|
||||||
|
|
||||||
// service values should be set
|
// service values should be set
|
||||||
suite.Equal("https://owncast.example.org/federation/user/rgh", service.URI)
|
suite.Equal("https://owncast.example.org/federation/user/rgh", service.URI)
|
||||||
suite.Equal("https://owncast.example.org/federation/user/rgh", service.URL)
|
suite.Equal("https://owncast.example.org/federation/user/rgh", service.URL)
|
||||||
|
suite.WithinDuration(time.Now(), service.LastWebfingeredAt, 5*time.Second)
|
||||||
|
|
||||||
// service should be in the database
|
// service should be in the database
|
||||||
dbService, err := suite.db.GetAccountByURI(context.Background(), service.URI)
|
dbService, err := suite.db.GetAccountByURI(context.Background(), service.URI)
|
||||||
|
@ -79,6 +81,102 @@ func (suite *AccountTestSuite) TestDereferenceService() {
|
||||||
suite.Equal("example.org", dbService.Domain)
|
suite.Equal("example.org", dbService.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
We shouldn't try webfingering or making http calls to dereference local accounts
|
||||||
|
that might be passed into GetRemoteAccount for whatever reason, so these tests are
|
||||||
|
here to make sure that such cases are (basically) short-circuit evaluated and given
|
||||||
|
back as-is without trying to make any calls to one's own instance.
|
||||||
|
*/
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsRemoteURL() {
|
||||||
|
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
|
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||||
|
RequestingUsername: fetchingAccount.Username,
|
||||||
|
RemoteAccountID: testrig.URLMustParse(targetAccount.URI),
|
||||||
|
})
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(fetchedAccount)
|
||||||
|
suite.Empty(fetchedAccount.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsername() {
|
||||||
|
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
|
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||||
|
RequestingUsername: fetchingAccount.Username,
|
||||||
|
RemoteAccountUsername: targetAccount.Username,
|
||||||
|
})
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(fetchedAccount)
|
||||||
|
suite.Empty(fetchedAccount.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomain() {
|
||||||
|
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
|
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||||
|
RequestingUsername: fetchingAccount.Username,
|
||||||
|
RemoteAccountUsername: targetAccount.Username,
|
||||||
|
RemoteAccountHost: config.GetHost(),
|
||||||
|
})
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(fetchedAccount)
|
||||||
|
suite.Empty(fetchedAccount.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomainAndURL() {
|
||||||
|
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
targetAccount := suite.testAccounts["local_account_2"]
|
||||||
|
|
||||||
|
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||||
|
RequestingUsername: fetchingAccount.Username,
|
||||||
|
RemoteAccountID: testrig.URLMustParse(targetAccount.URI),
|
||||||
|
RemoteAccountUsername: targetAccount.Username,
|
||||||
|
RemoteAccountHost: config.GetHost(),
|
||||||
|
})
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.NotNil(fetchedAccount)
|
||||||
|
suite.Empty(fetchedAccount.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsername() {
|
||||||
|
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
|
||||||
|
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||||
|
RequestingUsername: fetchingAccount.Username,
|
||||||
|
RemoteAccountUsername: "thisaccountdoesnotexist",
|
||||||
|
})
|
||||||
|
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
|
||||||
|
suite.Nil(fetchedAccount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsernameDomain() {
|
||||||
|
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
|
||||||
|
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||||
|
RequestingUsername: fetchingAccount.Username,
|
||||||
|
RemoteAccountUsername: "thisaccountdoesnotexist",
|
||||||
|
RemoteAccountHost: "localhost:8080",
|
||||||
|
})
|
||||||
|
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
|
||||||
|
suite.Nil(fetchedAccount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUserURI() {
|
||||||
|
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||||
|
|
||||||
|
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||||
|
RequestingUsername: fetchingAccount.Username,
|
||||||
|
RemoteAccountID: testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"),
|
||||||
|
})
|
||||||
|
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
|
||||||
|
suite.Nil(fetchedAccount)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountTestSuite(t *testing.T) {
|
func TestAccountTestSuite(t *testing.T) {
|
||||||
suite.Run(t, new(AccountTestSuite))
|
suite.Run(t, new(AccountTestSuite))
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,6 @@ import (
|
||||||
|
|
||||||
func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) {
|
func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) {
|
||||||
l := log.WithFields(kv.Fields{
|
l := log.WithFields(kv.Fields{
|
||||||
|
|
||||||
{"query", search.Query},
|
{"query", search.Query},
|
||||||
}...)
|
}...)
|
||||||
|
|
||||||
|
@ -62,7 +61,7 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
|
||||||
|
|
||||||
/*
|
/*
|
||||||
SEARCH BY MENTION
|
SEARCH BY MENTION
|
||||||
check if the query is something like @whatever_username@example.org -- this means it's a remote account
|
check if the query is something like @whatever_username@example.org -- this means it's likely a remote account
|
||||||
*/
|
*/
|
||||||
maybeNamestring := query
|
maybeNamestring := query
|
||||||
if maybeNamestring[0] != '@' {
|
if maybeNamestring[0] != '@' {
|
||||||
|
@ -135,7 +134,6 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
|
||||||
|
|
||||||
func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) {
|
func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) {
|
||||||
l := log.WithFields(kv.Fields{
|
l := log.WithFields(kv.Fields{
|
||||||
|
|
||||||
{"uri", uri.String()},
|
{"uri", uri.String()},
|
||||||
{"resolve", resolve},
|
{"resolve", resolve},
|
||||||
}...)
|
}...)
|
||||||
|
@ -161,67 +159,46 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) {
|
func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) {
|
||||||
if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil {
|
// it might be a web url like http://example.org/@user instead
|
||||||
return maybeAccount, nil
|
// of an AP uri like http://example.org/users/user, check first
|
||||||
} else if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil {
|
if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil {
|
||||||
return maybeAccount, nil
|
return maybeAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if resolve {
|
if uri.Host == config.GetHost() || uri.Host == config.GetAccountDomain() {
|
||||||
// we don't have it locally so try and dereference it
|
// this is a local account; if we don't have it now then
|
||||||
account, err := p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
|
// we should just bail instead of trying to get it remote
|
||||||
RequestingUsername: authed.Account.Username,
|
if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil {
|
||||||
RemoteAccountID: uri,
|
return maybeAccount, nil
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("searchAccountByURI: error dereferencing account with uri %s: %s", uri.String(), err)
|
|
||||||
}
|
}
|
||||||
return account, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, nil
|
|
||||||
|
// we don't have it yet, try to find it remotely
|
||||||
|
return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
|
||||||
|
RequestingUsername: authed.Account.Username,
|
||||||
|
RemoteAccountID: uri,
|
||||||
|
Blocking: true,
|
||||||
|
SkipResolve: !resolve,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, username string, domain string, resolve bool) (*gtsmodel.Account, error) {
|
func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, username string, domain string, resolve bool) (*gtsmodel.Account, error) {
|
||||||
maybeAcct := >smodel.Account{}
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// if it's a local account we can skip a whole bunch of stuff
|
// if it's a local account we can skip a whole bunch of stuff
|
||||||
if domain == config.GetHost() || domain == config.GetAccountDomain() || domain == "" {
|
if domain == config.GetHost() || domain == config.GetAccountDomain() || domain == "" {
|
||||||
maybeAcct, err = p.db.GetLocalAccountByUsername(ctx, username)
|
maybeAcct, err := p.db.GetLocalAccountByUsername(ctx, username)
|
||||||
if err != nil && err != db.ErrNoEntries {
|
if err == nil || err == db.ErrNoEntries {
|
||||||
return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err)
|
return maybeAcct, nil
|
||||||
}
|
}
|
||||||
return maybeAcct, nil
|
return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// it's not a local account so first we'll check if it's in the database already...
|
// we don't have it yet, try to find it remotely
|
||||||
where := []db.Where{
|
return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
|
||||||
{Key: "username", Value: username, CaseInsensitive: true},
|
RequestingUsername: authed.Account.Username,
|
||||||
{Key: "domain", Value: domain, CaseInsensitive: true},
|
RemoteAccountUsername: username,
|
||||||
}
|
RemoteAccountHost: domain,
|
||||||
err = p.db.GetWhere(ctx, where, maybeAcct)
|
Blocking: true,
|
||||||
if err == nil {
|
SkipResolve: !resolve,
|
||||||
// we've got it stored locally already!
|
})
|
||||||
return maybeAcct, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != db.ErrNoEntries {
|
|
||||||
// if it's not errNoEntries there's been a real database error so bail at this point
|
|
||||||
return nil, fmt.Errorf("searchAccountByMention: database error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// we got a db.ErrNoEntries, so we just don't have the account locally stored -- check if we can dereference it
|
|
||||||
if resolve {
|
|
||||||
maybeAcct, err = p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
|
|
||||||
RequestingUsername: authed.Account.Username,
|
|
||||||
RemoteAccountUsername: username,
|
|
||||||
RemoteAccountHost: domain,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("searchAccountByMention: error getting remote account: %s", err)
|
|
||||||
}
|
|
||||||
return maybeAcct, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue