diff --git a/internal/api/client/account/accountupdate.go b/internal/api/client/account/accountupdate.go index c38ede252..9a377f3b8 100644 --- a/internal/api/client/account/accountupdate.go +++ b/internal/api/client/account/accountupdate.go @@ -19,7 +19,9 @@ package account import ( + "fmt" "net/http" + "strconv" "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -107,17 +109,24 @@ func (m *Module) AccountUpdateCredentialsPATCHHandler(c *gin.Context) { } l.Tracef("retrieved account %+v", authed.Account.ID) - form := &model.UpdateCredentialsRequest{} - if err := c.ShouldBind(&form); err != nil || form == nil { - l.Debugf("could not parse form from request: %s", err) + form, err := parseUpdateAccountForm(c) + if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - l.Debugf("parsed request form %+v", form) - // if everything on the form is nil, then nothing has been set and we shouldn't continue - if form.Discoverable == nil && form.Bot == nil && form.DisplayName == nil && form.Note == nil && form.Avatar == nil && form.Header == nil && form.Locked == nil && form.Source == nil && form.FieldsAttributes == nil { + if form.Discoverable == nil && + form.Bot == nil && + form.DisplayName == nil && + form.Note == nil && + form.Avatar == nil && + form.Header == nil && + form.Locked == nil && + form.Source.Privacy == nil && + form.Source.Sensitive == nil && + form.Source.Language == nil && + form.FieldsAttributes == nil { l.Debugf("could not parse form from request") c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"}) return @@ -133,3 +142,34 @@ func (m *Module) AccountUpdateCredentialsPATCHHandler(c *gin.Context) { l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive) c.JSON(http.StatusOK, acctSensitive) } + +func parseUpdateAccountForm(c *gin.Context) (*model.UpdateCredentialsRequest, error) { + // parse main fields from request + form := &model.UpdateCredentialsRequest{ + Source: &model.UpdateSource{}, + } + if err := c.ShouldBind(&form); err != nil || form == nil { + return nil, fmt.Errorf("could not parse form from request: %s", err) + } + + // parse source field-by-field + sourceMap := c.PostFormMap("source") + + if privacy, ok := sourceMap["privacy"]; ok { + form.Source.Privacy = &privacy + } + + if sensitive, ok := sourceMap["sensitive"]; ok { + sensitiveBool, err := strconv.ParseBool(sensitive) + if err != nil { + return nil, fmt.Errorf("error parsing form source[sensitive]: %s", err) + } + form.Source.Sensitive = &sensitiveBool + } + + if language, ok := sourceMap["language"]; ok { + form.Source.Language = &language + } + + return form, nil +} diff --git a/internal/api/client/account/accountupdate_test.go b/internal/api/client/account/accountupdate_test.go index a02573631..bafda0e01 100644 --- a/internal/api/client/account/accountupdate_test.go +++ b/internal/api/client/account/accountupdate_test.go @@ -19,8 +19,8 @@ package account_test import ( + "context" "encoding/json" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -37,22 +37,224 @@ type AccountUpdateTestSuite struct { AccountStandardTestSuite } -func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerSimple() { +func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler() { // set up the request - // we're updating the header image, the display name, and the locked status of zork - // we're removing the note/bio + // we're updating the note of zork + newBio := "this is my new bio read it and weep" requestBody, w, err := testrig.CreateMultipartFormData( - "header", "../../../../testrig/media/test-jpeg.jpg", + "", "", map[string]string{ - "display_name": "updated zork display name!!!", - "note": "", - "locked": "true", + "note": newBio, }) if err != nil { panic(err) } + bodyBytes := requestBody.Bytes() recorder := httptest.NewRecorder() - ctx := suite.newContext(recorder, http.MethodPatch, requestBody.Bytes(), account.UpdateCredentialsPath, w.FormDataContentType()) + ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) + + // call the handler + suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) + + // 1. we should have OK because our request was valid + suite.Equal(http.StatusOK, recorder.Code) + + // 2. we should have no error message in the result body + result := recorder.Result() + defer result.Body.Close() + + // check the response + b, err := ioutil.ReadAll(result.Body) + assert.NoError(suite.T(), err) + + // unmarshal the returned account + apimodelAccount := &apimodel.Account{} + err = json.Unmarshal(b, apimodelAccount) + suite.NoError(err) + + // check the returned api model account + // fields should be updated + suite.Equal("

this is my new bio read it and weep

", apimodelAccount.Note) +} + +func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerUnlockLock() { + // set up the first request + requestBody1, w1, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "locked": "false", + }) + if err != nil { + panic(err) + } + bodyBytes1 := requestBody1.Bytes() + recorder1 := httptest.NewRecorder() + ctx1 := suite.newContext(recorder1, http.MethodPatch, bodyBytes1, account.UpdateCredentialsPath, w1.FormDataContentType()) + + // call the handler + suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx1) + + // 1. we should have OK because our request was valid + suite.Equal(http.StatusOK, recorder1.Code) + + // 2. we should have no error message in the result body + result1 := recorder1.Result() + defer result1.Body.Close() + + // check the response + b1, err := ioutil.ReadAll(result1.Body) + assert.NoError(suite.T(), err) + + // unmarshal the returned account + apimodelAccount1 := &apimodel.Account{} + err = json.Unmarshal(b1, apimodelAccount1) + suite.NoError(err) + + // check the returned api model account + // fields should be updated + suite.False(apimodelAccount1.Locked) + + // set up the first request + requestBody2, w2, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "locked": "true", + }) + if err != nil { + panic(err) + } + bodyBytes2 := requestBody2.Bytes() + recorder2 := httptest.NewRecorder() + ctx2 := suite.newContext(recorder2, http.MethodPatch, bodyBytes2, account.UpdateCredentialsPath, w2.FormDataContentType()) + + // call the handler + suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx2) + + // 1. we should have OK because our request was valid + suite.Equal(http.StatusOK, recorder1.Code) + + // 2. we should have no error message in the result body + result2 := recorder2.Result() + defer result2.Body.Close() + + // check the response + b2, err := ioutil.ReadAll(result2.Body) + suite.NoError(err) + + // unmarshal the returned account + apimodelAccount2 := &apimodel.Account{} + err = json.Unmarshal(b2, apimodelAccount2) + suite.NoError(err) + + // check the returned api model account + // fields should be updated + suite.True(apimodelAccount2.Locked) +} + +func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerGetAccountFirst() { + // get the account first to make sure it's in the database cache -- when the account is updated via + // the PATCH handler, it should invalidate the cache and not return the old version + _, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID) + suite.NoError(err) + + // set up the request + // we're updating the note of zork + newBio := "this is my new bio read it and weep" + requestBody, w, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "note": newBio, + }) + if err != nil { + panic(err) + } + bodyBytes := requestBody.Bytes() + recorder := httptest.NewRecorder() + ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) + + // call the handler + suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) + + // 1. we should have OK because our request was valid + suite.Equal(http.StatusOK, recorder.Code) + + // 2. we should have no error message in the result body + result := recorder.Result() + defer result.Body.Close() + + // check the response + b, err := ioutil.ReadAll(result.Body) + assert.NoError(suite.T(), err) + + // unmarshal the returned account + apimodelAccount := &apimodel.Account{} + err = json.Unmarshal(b, apimodelAccount) + suite.NoError(err) + + // check the returned api model account + // fields should be updated + suite.Equal("

this is my new bio read it and weep

", apimodelAccount.Note) +} + +func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerTwoFields() { + // set up the request + // we're updating the note of zork, and setting locked to true + newBio := "this is my new bio read it and weep" + requestBody, w, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "note": newBio, + "locked": "true", + }) + if err != nil { + panic(err) + } + bodyBytes := requestBody.Bytes() + recorder := httptest.NewRecorder() + ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) + + // call the handler + suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) + + // 1. we should have OK because our request was valid + suite.Equal(http.StatusOK, recorder.Code) + + // 2. we should have no error message in the result body + result := recorder.Result() + defer result.Body.Close() + + // check the response + b, err := ioutil.ReadAll(result.Body) + assert.NoError(suite.T(), err) + + // unmarshal the returned account + apimodelAccount := &apimodel.Account{} + err = json.Unmarshal(b, apimodelAccount) + suite.NoError(err) + + // check the returned api model account + // fields should be updated + suite.Equal("

this is my new bio read it and weep

", apimodelAccount.Note) + suite.True(apimodelAccount.Locked) +} + +func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerWithMedia() { + // set up the request + // we're updating the header image, the display name, and the locked status of zork + // we're removing the note/bio + requestBody, w, err := testrig.CreateMultipartFormData( + "header", "../../../../testrig/media/test-jpeg.jpg", + map[string]string{ + "display_name": "updated zork display name!!!", + "note": "", + "locked": "true", + }) + if err != nil { + panic(err) + } + bodyBytes := requestBody.Bytes() + recorder := httptest.NewRecorder() + ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) // call the handler suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) @@ -67,7 +269,6 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerSim // check the response b, err := ioutil.ReadAll(result.Body) assert.NoError(suite.T(), err) - fmt.Println(string(b)) // unmarshal the returned account apimodelAccount := &apimodel.Account{} @@ -90,6 +291,74 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerSim suite.NotEqual("http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpeg", apimodelAccount.HeaderStatic) } +func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerEmptyForm() { + // set up the request + bodyBytes := []byte{} + recorder := httptest.NewRecorder() + ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, "") + + // call the handler + suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) + + // 1. we should have OK because our request was valid + suite.Equal(http.StatusBadRequest, recorder.Code) + + // 2. we should have no error message in the result body + result := recorder.Result() + defer result.Body.Close() + + // check the response + b, err := ioutil.ReadAll(result.Body) + assert.NoError(suite.T(), err) + suite.Equal(`{"error":"empty form submitted"}`, string(b)) +} + +func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandlerUpdateSource() { + // set up the request + // we're updating the language of zork + newLanguage := "de" + requestBody, w, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "source[privacy]": string(apimodel.VisibilityPrivate), + "source[language]": "de", + "source[sensitive]": "true", + "locked": "true", + }) + if err != nil { + panic(err) + } + bodyBytes := requestBody.Bytes() + recorder := httptest.NewRecorder() + ctx := suite.newContext(recorder, http.MethodPatch, bodyBytes, account.UpdateCredentialsPath, w.FormDataContentType()) + + // call the handler + suite.accountModule.AccountUpdateCredentialsPATCHHandler(ctx) + + // 1. we should have OK because our request was valid + suite.Equal(http.StatusOK, recorder.Code) + + // 2. we should have no error message in the result body + result := recorder.Result() + defer result.Body.Close() + + // check the response + b, err := ioutil.ReadAll(result.Body) + assert.NoError(suite.T(), err) + + // unmarshal the returned account + apimodelAccount := &apimodel.Account{} + err = json.Unmarshal(b, apimodelAccount) + suite.NoError(err) + + // check the returned api model account + // fields should be updated + suite.Equal(newLanguage, apimodelAccount.Source.Language) + suite.EqualValues(apimodel.VisibilityPrivate, apimodelAccount.Source.Privacy) + suite.True(apimodelAccount.Source.Sensitive) + suite.True(apimodelAccount.Locked) +} + func TestAccountUpdateTestSuite(t *testing.T) { suite.Run(t, new(AccountUpdateTestSuite)) } diff --git a/internal/api/model/status.go b/internal/api/model/status.go index c5b5a4640..8be1a4870 100644 --- a/internal/api/model/status.go +++ b/internal/api/model/status.go @@ -160,6 +160,7 @@ type StatusCreateRequest struct { // - public // - unlisted // - private +// - mutuals_only // - direct type Visibility string diff --git a/internal/db/basic.go b/internal/db/basic.go index d94c98e45..44000ef24 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -66,10 +66,6 @@ type Basic interface { // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. UpdateByPrimaryKey(ctx context.Context, i interface{}) Error - // UpdateOneByPrimaryKey sets one column of interface, with the given key, to the given value. - // It uses the primary key of interface i to decide which row to update. This is usually the `id`. - UpdateOneByPrimaryKey(ctx context.Context, key string, value interface{}, i interface{}) Error - // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 32a70f7cd..745e41567 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -22,7 +22,6 @@ import ( "context" "errors" "fmt" - "strings" "time" "github.com/superseriousbusiness/gotosocial/internal/cache" @@ -103,16 +102,15 @@ func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.A } func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { - if strings.TrimSpace(account.ID) == "" { - // TODO: we should not need this check here - return nil, errors.New("account had no ID") - } - - // Update the account's last-used + // Update the account's last-updated account.UpdatedAt = time.Now() // Update the account model in the DB - _, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx) + _, err := a.conn. + NewUpdate(). + Model(account). + WherePK(). + Exec(ctx) if err != nil { return nil, a.conn.ProcessError(err) } diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index 1e7880379..e5a1fbaf9 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -105,16 +105,6 @@ func (b *basicDB) UpdateByPrimaryKey(ctx context.Context, i interface{}) db.Erro return b.conn.ProcessError(err) } -func (b *basicDB) UpdateOneByPrimaryKey(ctx context.Context, key string, value interface{}, i interface{}) db.Error { - q := b.conn.NewUpdate(). - Model(i). - Set("? = ?", bun.Safe(key), value). - WherePK() - - _, err := q.Exec(ctx) - return b.conn.ProcessError(err) -} - func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { q := b.conn.NewUpdate().Model(i) diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go index acdfb6640..e5f7e159a 100644 --- a/internal/db/bundb/basic_test.go +++ b/internal/db/bundb/basic_test.go @@ -64,40 +64,6 @@ func (suite *BasicTestSuite) TestGetAllNotNull() { } } -func (suite *BasicTestSuite) TestUpdateOneByPrimaryKeySetEmpty() { - testAccount := suite.testAccounts["local_account_1"] - - // try removing the note from zork - err := suite.db.UpdateOneByPrimaryKey(context.Background(), "note", "", testAccount) - suite.NoError(err) - - // get zork out of the database - dbAccount, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) - suite.NoError(err) - suite.NotNil(dbAccount) - - // note should be empty now - suite.Empty(dbAccount.Note) -} - -func (suite *BasicTestSuite) TestUpdateOneByPrimaryKeySetValue() { - testAccount := suite.testAccounts["local_account_1"] - - note := "this is my new note :)" - - // try updating the note on zork - err := suite.db.UpdateOneByPrimaryKey(context.Background(), "note", note, testAccount) - suite.NoError(err) - - // get zork out of the database - dbAccount, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) - suite.NoError(err) - suite.NotNil(dbAccount) - - // note should be set now - suite.Equal(note, dbAccount.Note) -} - func TestBasicTestSuite(t *testing.T) { suite.Run(t, new(BasicTestSuite)) } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 7ddcab5c7..400535da7 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -147,6 +147,11 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type)) } + if log.Level >= logrus.TraceLevel { + // add a hook to just log queries and the time they take + conn.DB.AddQueryHook(newDebugQueryHook(log)) + } + // actually *begin* the connection so that we can tell if the db is there and listening if err := conn.Ping(); err != nil { return nil, fmt.Errorf("db connection error: %s", err) @@ -402,7 +407,7 @@ func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAcco return menchies, nil } -func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { +func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) { newTags := []*gtsmodel.Tag{} for _, t := range tags { tag := >smodel.Tag{} @@ -438,7 +443,7 @@ func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, ori return newTags, nil } -func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { +func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string) ([]*gtsmodel.Emoji, error) { newEmojis := []*gtsmodel.Emoji{} for _, e := range emojis { emoji := >smodel.Emoji{} diff --git a/internal/db/bundb/trace.go b/internal/db/bundb/trace.go new file mode 100644 index 000000000..e62e8c01f --- /dev/null +++ b/internal/db/bundb/trace.go @@ -0,0 +1,53 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package bundb + +import ( + "context" + "time" + + "github.com/sirupsen/logrus" + "github.com/uptrace/bun" +) + +func newDebugQueryHook(log *logrus.Logger) bun.QueryHook { + return &debugQueryHook{ + log: log, + } +} + +// debugQueryHook implements bun.QueryHook +type debugQueryHook struct { + log *logrus.Logger +} + +func (q *debugQueryHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context { + // do nothing + return ctx +} + +// AfterQuery logs the time taken to query, the operation (select, update, etc), and the query itself as translated by bun. +func (q *debugQueryHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) { + dur := time.Since(event.StartTime).Round(time.Microsecond) + l := q.log.WithFields(logrus.Fields{ + "queryTime": dur, + "operation": event.Operation(), + }) + l.Trace(event.Query) +} diff --git a/internal/db/db.go b/internal/db/db.go index ec94fcfe7..9a93322cb 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -64,7 +64,7 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking // if they exist in the db already, and conveniently returning them, or creating new tag structs. - TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) + TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) // EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been // used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then @@ -72,5 +72,5 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking // if they exist in the db and conveniently returning them if they do. - EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) + EmojiStringsToEmojis(ctx context.Context, emojis []string) ([]*gtsmodel.Emoji, error) } diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go index 71b876d3b..831607d94 100644 --- a/internal/processing/account/account.go +++ b/internal/processing/account/account.go @@ -32,6 +32,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/oauth2/v4" @@ -83,6 +84,7 @@ type processor struct { fromClientAPI chan messages.FromClientAPI oauthServer oauth.Server filter visibility.Filter + formatter text.Formatter db db.DB federator federation.Federator log *logrus.Logger @@ -97,6 +99,7 @@ func New(db db.DB, tc typeutils.TypeConverter, mediaHandler media.Handler, oauth fromClientAPI: fromClientAPI, oauthServer: oauthServer, filter: visibility.NewFilter(db, log), + formatter: text.NewFormatter(config, db, log), db: db, federator: federator, log: log, diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go index e997a95c7..6dc288849 100644 --- a/internal/processing/account/update.go +++ b/internal/processing/account/update.go @@ -32,6 +32,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/text" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/validate" ) @@ -39,35 +40,29 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form l := p.log.WithField("func", "AccountUpdate") if form.Discoverable != nil { - if err := p.db.UpdateOneByPrimaryKey(ctx, "discoverable", *form.Discoverable, account); err != nil { - return nil, fmt.Errorf("error updating discoverable: %s", err) - } + account.Discoverable = *form.Discoverable } if form.Bot != nil { - if err := p.db.UpdateOneByPrimaryKey(ctx, "bot", *form.Bot, account); err != nil { - return nil, fmt.Errorf("error updating bot: %s", err) - } + account.Bot = *form.Bot } if form.DisplayName != nil { if err := validate.DisplayName(*form.DisplayName); err != nil { return nil, err } - displayName := text.RemoveHTML(*form.DisplayName) // no html allowed in display name - if err := p.db.UpdateOneByPrimaryKey(ctx, "display_name", displayName, account); err != nil { - return nil, err - } + account.DisplayName = text.RemoveHTML(*form.DisplayName) } if form.Note != nil { if err := validate.Note(*form.Note); err != nil { return nil, err } - note := text.SanitizeHTML(*form.Note) // html OK in note but sanitize it - if err := p.db.UpdateOneByPrimaryKey(ctx, "note", note, account); err != nil { + note, err := p.processNote(ctx, *form.Note, account.ID) + if err != nil { return nil, err } + account.Note = note } if form.Avatar != nil && form.Avatar.Size != 0 { @@ -75,6 +70,8 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form if err != nil { return nil, err } + account.AvatarMediaAttachmentID = avatarInfo.ID + account.AvatarMediaAttachment = avatarInfo l.Tracef("new avatar info for account %s is %+v", account.ID, avatarInfo) } @@ -83,13 +80,13 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form if err != nil { return nil, err } + account.HeaderMediaAttachmentID = headerInfo.ID + account.HeaderMediaAttachment = headerInfo l.Tracef("new header info for account %s is %+v", account.ID, headerInfo) } if form.Locked != nil { - if err := p.db.UpdateOneByPrimaryKey(ctx, "locked", *form.Locked, account); err != nil { - return nil, err - } + account.Locked = *form.Locked } if form.Source != nil { @@ -97,31 +94,25 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form if err := validate.Language(*form.Source.Language); err != nil { return nil, err } - if err := p.db.UpdateOneByPrimaryKey(ctx, "language", *form.Source.Language, account); err != nil { - return nil, err - } + account.Language = *form.Source.Language } if form.Source.Sensitive != nil { - if err := p.db.UpdateOneByPrimaryKey(ctx, "locked", *form.Locked, account); err != nil { - return nil, err - } + account.Sensitive = *form.Source.Sensitive } if form.Source.Privacy != nil { if err := validate.Privacy(*form.Source.Privacy); err != nil { return nil, err } - if err := p.db.UpdateOneByPrimaryKey(ctx, "privacy", *form.Source.Privacy, account); err != nil { - return nil, err - } + privacy := p.tc.MastoVisToVis(apimodel.Visibility(*form.Source.Privacy)) + account.Privacy = privacy } } - // fetch the account with all updated values set - updatedAccount, err := p.db.GetAccountByID(ctx, account.ID) + updatedAccount, err := p.db.UpdateAccount(ctx, account) if err != nil { - return nil, fmt.Errorf("could not fetch updated account %s: %s", account.ID, err) + return nil, fmt.Errorf("could not update account %s: %s", account.ID, err) } p.fromClientAPI <- messages.FromClientAPI{ @@ -203,3 +194,27 @@ func (p *processor) UpdateHeader(ctx context.Context, header *multipart.FileHead return headerInfo, f.Close() } + +func (p *processor) processNote(ctx context.Context, note string, accountID string) (string, error) { + if note == "" { + return "", nil + } + + tagStrings := util.DeriveHashtagsFromText(note) + tags, err := p.db.TagStringsToTags(ctx, tagStrings, accountID) + if err != nil { + return "", err + } + + mentionStrings := util.DeriveMentionsFromText(note) + mentions, err := p.db.MentionStringsToMentions(ctx, mentionStrings, accountID, "") + if err != nil { + return "", err + } + + // TODO: support emojis in account notes + // emojiStrings := util.DeriveEmojisFromText(note) + // emojis, err := p.db.EmojiStringsToEmojis(ctx, emojiStrings) + + return p.formatter.FromPlain(ctx, note, mentions, tags), nil +} diff --git a/internal/processing/account/update_test.go b/internal/processing/account/update_test.go index b18a5e42e..63370cd39 100644 --- a/internal/processing/account/update_test.go +++ b/internal/processing/account/update_test.go @@ -36,7 +36,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateSimple() { locked := true displayName := "new display name" - note := "" + note := "#hello here i am!" form := &apimodel.UpdateCredentialsRequest{ DisplayName: &displayName, @@ -52,7 +52,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateSimple() { // fields on the profile should be updated suite.True(apiAccount.Locked) suite.Equal(displayName, apiAccount.DisplayName) - suite.Empty(apiAccount.Note) + suite.Equal(`

#hello here i am!

`, apiAccount.Note) // we should have an update in the client api channel msg := <-suite.fromClientAPIChan @@ -67,7 +67,50 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateSimple() { suite.NoError(err) suite.True(dbAccount.Locked) suite.Equal(displayName, dbAccount.DisplayName) - suite.Empty(dbAccount.Note) + suite.Equal(`

#hello here i am!

`, dbAccount.Note) +} + +func (suite *AccountUpdateTestSuite) TestAccountUpdateWithMention() { + testAccount := suite.testAccounts["local_account_1"] + + locked := true + displayName := "new display name" + note := `#hello here i am! + +go check out @1happyturtle, they have a cool account! +` + noteExpected := `

#hello here i am!

go check out @1happyturtle, they have a cool account!

` + + form := &apimodel.UpdateCredentialsRequest{ + DisplayName: &displayName, + Locked: &locked, + Note: ¬e, + } + + // should get no error from the update function, and an api model account returned + apiAccount, err := suite.accountProcessor.Update(context.Background(), testAccount, form) + suite.NoError(err) + suite.NotNil(apiAccount) + + // fields on the profile should be updated + suite.True(apiAccount.Locked) + suite.Equal(displayName, apiAccount.DisplayName) + suite.Equal(noteExpected, apiAccount.Note) + + // we should have an update in the client api channel + msg := <-suite.fromClientAPIChan + suite.Equal(ap.ActivityUpdate, msg.APActivityType) + suite.Equal(ap.ObjectProfile, msg.APObjectType) + suite.NotNil(msg.OriginAccount) + suite.Equal(testAccount.ID, msg.OriginAccount.ID) + suite.Nil(msg.TargetAccount) + + // fields should be updated in the database as well + dbAccount, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) + suite.NoError(err) + suite.True(dbAccount.Locked) + suite.Equal(displayName, dbAccount.DisplayName) + suite.Equal(noteExpected, dbAccount.Note) } func TestAccountUpdateTestSuite(t *testing.T) { diff --git a/internal/processing/status/util.go b/internal/processing/status/util.go index 5ed63d919..edbb9a31a 100644 --- a/internal/processing/status/util.go +++ b/internal/processing/status/util.go @@ -192,7 +192,7 @@ func (p *processor) ProcessLanguage(ctx context.Context, form *apimodel.Advanced func (p *processor) ProcessMentions(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { menchies := []string{} - gtsMenchies, err := p.db.MentionStringsToMentions(ctx, util.DeriveMentionsFromStatus(form.Status), accountID, status.ID) + gtsMenchies, err := p.db.MentionStringsToMentions(ctx, util.DeriveMentionsFromText(form.Status), accountID, status.ID) if err != nil { return fmt.Errorf("error generating mentions from status: %s", err) } @@ -217,7 +217,7 @@ func (p *processor) ProcessMentions(ctx context.Context, form *apimodel.Advanced func (p *processor) ProcessTags(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { tags := []string{} - gtsTags, err := p.db.TagStringsToTags(ctx, util.DeriveHashtagsFromStatus(form.Status), accountID, status.ID) + gtsTags, err := p.db.TagStringsToTags(ctx, util.DeriveHashtagsFromText(form.Status), accountID) if err != nil { return fmt.Errorf("error generating hashtags from status: %s", err) } @@ -236,7 +236,7 @@ func (p *processor) ProcessTags(ctx context.Context, form *apimodel.AdvancedStat func (p *processor) ProcessEmojis(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { emojis := []string{} - gtsEmojis, err := p.db.EmojiStringsToEmojis(ctx, util.DeriveEmojisFromStatus(form.Status), accountID, status.ID) + gtsEmojis, err := p.db.EmojiStringsToEmojis(ctx, util.DeriveEmojisFromText(form.Status)) if err != nil { return fmt.Errorf("error generating emojis from status: %s", err) } diff --git a/internal/typeutils/converter.go b/internal/typeutils/converter.go index 630e48300..40cb5b969 100644 --- a/internal/typeutils/converter.go +++ b/internal/typeutils/converter.go @@ -178,20 +178,18 @@ type TypeConverter interface { } type converter struct { - config *config.Config - db db.DB - log *logrus.Logger - frontendCache cache.Cache - asCache cache.Cache + config *config.Config + db db.DB + log *logrus.Logger + asCache cache.Cache } // NewConverter returns a new Converter func NewConverter(config *config.Config, db db.DB, log *logrus.Logger) TypeConverter { return &converter{ - config: config, - db: db, - log: log, - frontendCache: cache.New(), - asCache: cache.New(), + config: config, + db: db, + log: log, + asCache: cache.New(), } } diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 7924e2185..67d6cef94 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -67,14 +67,6 @@ func (c *converter) AccountToMastoPublic(ctx context.Context, a *gtsmodel.Accoun return nil, fmt.Errorf("given account was nil") } - // first check if we have this account in our frontEnd cache - if accountI, err := c.frontendCache.Fetch(a.ID); err == nil { - if account, ok := accountI.(*model.Account); ok { - // we have it, so just return it as-is - return account, nil - } - } - // count followers followersCount, err := c.db.CountAccountFollowedBy(ctx, a.ID, false) if err != nil { @@ -184,11 +176,6 @@ func (c *converter) AccountToMastoPublic(ctx context.Context, a *gtsmodel.Accoun Suspended: suspended, } - // put the account in our cache in case we need it again soon - if err := c.frontendCache.Store(a.ID, accountFrontend); err != nil { - return nil, err - } - return accountFrontend, nil } diff --git a/internal/util/statustools.go b/internal/util/statustools.go index ca18577b0..95ce63a5b 100644 --- a/internal/util/statustools.go +++ b/internal/util/statustools.go @@ -25,38 +25,38 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/regexes" ) -// DeriveMentionsFromStatus takes a plaintext (ie., not html-formatted) status, +// DeriveMentionsFromText takes a plaintext (ie., not html-formatted) text, // and applies a regex to it to return a deduplicated list of accounts -// mentioned in that status. +// mentioned in that text. // // It will look for fully-qualified account names in the form "@user@example.org". // or the form "@username" for local users. -func DeriveMentionsFromStatus(status string) []string { +func DeriveMentionsFromText(text string) []string { mentionedAccounts := []string{} - for _, m := range regexes.MentionFinder.FindAllStringSubmatch(status, -1) { + for _, m := range regexes.MentionFinder.FindAllStringSubmatch(text, -1) { mentionedAccounts = append(mentionedAccounts, m[1]) } return UniqueStrings(mentionedAccounts) } -// DeriveHashtagsFromStatus takes a plaintext (ie., not html-formatted) status, +// DeriveHashtagsFromText takes a plaintext (ie., not html-formatted) text, // and applies a regex to it to return a deduplicated list of hashtags -// used in that status, without the leading #. The case of the returned +// used in that text, without the leading #. The case of the returned // tags will be lowered, for consistency. -func DeriveHashtagsFromStatus(status string) []string { +func DeriveHashtagsFromText(text string) []string { tags := []string{} - for _, m := range regexes.HashtagFinder.FindAllStringSubmatch(status, -1) { + for _, m := range regexes.HashtagFinder.FindAllStringSubmatch(text, -1) { tags = append(tags, strings.TrimPrefix(m[1], "#")) } return UniqueStrings(tags) } -// DeriveEmojisFromStatus takes a plaintext (ie., not html-formatted) status, +// DeriveEmojisFromText takes a plaintext (ie., not html-formatted) text, // and applies a regex to it to return a deduplicated list of emojis -// used in that status, without the surround ::. -func DeriveEmojisFromStatus(status string) []string { +// used in that text, without the surrounding `::` +func DeriveEmojisFromText(text string) []string { emojis := []string{} - for _, m := range regexes.EmojiFinder.FindAllStringSubmatch(status, -1) { + for _, m := range regexes.EmojiFinder.FindAllStringSubmatch(text, -1) { emojis = append(emojis, m[1]) } return UniqueStrings(emojis) diff --git a/internal/util/statustools_test.go b/internal/util/statustools_test.go index 0ec2719f5..447315b25 100644 --- a/internal/util/statustools_test.go +++ b/internal/util/statustools_test.go @@ -45,7 +45,7 @@ func (suite *StatusTestSuite) TestDeriveMentionsOK() { ` - menchies := util.DeriveMentionsFromStatus(statusText) + menchies := util.DeriveMentionsFromText(statusText) assert.Len(suite.T(), menchies, 6) assert.Equal(suite.T(), "@dumpsterqueer@example.org", menchies[0]) assert.Equal(suite.T(), "@someone_else@testing.best-horse.com", menchies[1]) @@ -57,7 +57,7 @@ func (suite *StatusTestSuite) TestDeriveMentionsOK() { func (suite *StatusTestSuite) TestDeriveMentionsEmpty() { statusText := `` - menchies := util.DeriveMentionsFromStatus(statusText) + menchies := util.DeriveMentionsFromText(statusText) assert.Len(suite.T(), menchies, 0) } @@ -74,7 +74,7 @@ func (suite *StatusTestSuite) TestDeriveHashtagsOK() { #111111 thisalsoshouldn'twork#### ##` - tags := util.DeriveHashtagsFromStatus(statusText) + tags := util.DeriveHashtagsFromText(statusText) assert.Len(suite.T(), tags, 5) assert.Equal(suite.T(), "testing123", tags[0]) assert.Equal(suite.T(), "also", tags[1]) @@ -97,7 +97,7 @@ Here's some normal text with an :emoji: at the end :underscores_ok_too: ` - tags := util.DeriveEmojisFromStatus(statusText) + tags := util.DeriveEmojisFromText(statusText) assert.Len(suite.T(), tags, 7) assert.Equal(suite.T(), "test", tags[0]) assert.Equal(suite.T(), "another", tags[1]) @@ -115,9 +115,9 @@ func (suite *StatusTestSuite) TestDeriveMultiple() { Text` - ms := util.DeriveMentionsFromStatus(statusText) - hs := util.DeriveHashtagsFromStatus(statusText) - es := util.DeriveEmojisFromStatus(statusText) + ms := util.DeriveMentionsFromText(statusText) + hs := util.DeriveHashtagsFromText(statusText) + es := util.DeriveEmojisFromText(statusText) assert.Len(suite.T(), ms, 1) assert.Equal(suite.T(), "@foss_satan@fossbros-anonymous.io", ms[0]) diff --git a/internal/validate/formvalidation.go b/internal/validate/formvalidation.go index 9f61578e7..7215c8fcd 100644 --- a/internal/validate/formvalidation.go +++ b/internal/validate/formvalidation.go @@ -23,6 +23,7 @@ import ( "fmt" "net/mail" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/regexes" pwv "github.com/wagslane/go-password-validator" "golang.org/x/text/language" @@ -126,8 +127,14 @@ func Note(note string) error { // Privacy checks that the desired privacy setting is valid func Privacy(privacy string) error { - // TODO: add some validation logic here -- length, characters, etc - return nil + if privacy == "" { + return fmt.Errorf("empty string for privacy not allowed") + } + switch apimodel.Visibility(privacy) { + case apimodel.VisibilityDirect, apimodel.VisibilityMutualsOnly, apimodel.VisibilityPrivate, apimodel.VisibilityPublic, apimodel.VisibilityUnlisted: + return nil + } + return fmt.Errorf("privacy %s was not recognized", privacy) } // EmojiShortcode just runs the given shortcode through the regular expression diff --git a/testrig/util.go b/testrig/util.go index 0fb8aa887..0410366e3 100644 --- a/testrig/util.go +++ b/testrig/util.go @@ -34,18 +34,21 @@ import ( // req.Header.Set("Content-Type", w.FormDataContentType()) func CreateMultipartFormData(fieldName string, fileName string, extraFields map[string]string) (bytes.Buffer, *multipart.Writer, error) { var b bytes.Buffer - var err error + w := multipart.NewWriter(&b) var fw io.Writer - file, err := os.Open(fileName) - if err != nil { - return b, nil, err - } - if fw, err = w.CreateFormFile(fieldName, file.Name()); err != nil { - return b, nil, err - } - if _, err = io.Copy(fw, file); err != nil { - return b, nil, err + + if fileName != "" { + file, err := os.Open(fileName) + if err != nil { + return b, nil, err + } + if fw, err = w.CreateFormFile(fieldName, file.Name()); err != nil { + return b, nil, err + } + if _, err = io.Copy(fw, file); err != nil { + return b, nil, err + } } for k, v := range extraFields {