mirror of
1
Fork 0

[bugfix] Update poll delete/update db queries (#2361)

This commit is contained in:
tobi 2023-11-14 13:43:27 +01:00 committed by GitHub
parent 8d0c017cf2
commit 0b99f14d64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 98 additions and 39 deletions

View File

@ -44,7 +44,7 @@ func init() {
Table("polls"). Table("polls").
Column("expires_at_new"). Column("expires_at_new").
Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")). Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")).
Where("1"). // bun gets angry performing update over all rows Where("TRUE"). // bun gets angry performing update over all rows
Exec(ctx); err != nil { Exec(ctx); err != nil {
return err return err
} }

View File

@ -341,9 +341,12 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
var poll gtsmodel.Poll var poll gtsmodel.Poll
// Select poll counts from DB. // Select current poll counts from DB,
// taking minimal columns needed to
// increment/decrement votes.
if err := tx.NewSelect(). if err := tx.NewSelect().
Model(&poll). Model(&poll).
Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), vote.PollID). Where("? = ?", bun.Ident("id"), vote.PollID).
Scan(ctx); err != nil { Scan(ctx); err != nil {
return err return err
@ -365,31 +368,35 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
err := p.db.RunInTx(ctx, func(tx Tx) error { err := p.db.RunInTx(ctx, func(tx Tx) error {
// Delete all vote in poll, // Delete all votes in poll.
// returning all vote choices. res, err := tx.NewDelete().
switch _, err := tx.NewDelete().
Table("poll_votes"). Table("poll_votes").
Where("? = ?", bun.Ident("poll_id"), pollID). Where("? = ?", bun.Ident("poll_id"), pollID).
Exec(ctx); { Exec(ctx)
if err != nil {
case err == nil: // irrecoverable
// no issue.
case errors.Is(err, db.ErrNoEntries):
// no votes found,
// return here.
return nil
default:
// irrecoverable.
return err return err
} }
var poll gtsmodel.Poll ra, err := res.RowsAffected()
if err != nil {
// irrecoverable
return err
}
// Select poll counts from DB. if ra == 0 {
// No poll votes deleted,
// nothing to update.
return nil
}
// Select current poll counts from DB,
// taking minimal columns needed to
// increment/decrement votes.
var poll gtsmodel.Poll
switch err := tx.NewSelect(). switch err := tx.NewSelect().
Model(&poll). Model(&poll).
Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), pollID). Where("? = ?", bun.Ident("id"), pollID).
Scan(ctx); { Scan(ctx); {
@ -410,7 +417,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
poll.ResetVotes() poll.ResetVotes()
// Finally, update the poll entry. // Finally, update the poll entry.
_, err := tx.NewUpdate(). _, err = tx.NewUpdate().
Model(&poll). Model(&poll).
Column("votes", "voters"). Column("votes", "voters").
Where("? = ?", bun.Ident("id"), pollID). Where("? = ?", bun.Ident("id"), pollID).
@ -432,35 +439,37 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
err := p.db.RunInTx(ctx, func(tx Tx) error { err := p.db.RunInTx(ctx, func(tx Tx) error {
var choices []int // Slice should only ever be of length
// 0 or 1; it's a slice of slices only
// because we can't LIMIT deletes to 1.
var choicesSl [][]int
// Delete vote in poll by account, // Delete vote in poll by account,
// returning the ID + choices of the vote. // returning the ID + choices of the vote.
switch err := tx.NewDelete(). if err := tx.NewDelete().
Table("poll_votes"). Table("poll_votes").
Where("? = ?", bun.Ident("poll_id"), pollID). Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID). Where("? = ?", bun.Ident("account_id"), accountID).
Returning("choices"). Returning("?", bun.Ident("choices")).
Scan(ctx, &choices); { Scan(ctx, &choicesSl); err != nil {
case err == nil:
// no issue.
case errors.Is(err, db.ErrNoEntries):
// no votes found,
// return here.
return nil
default:
// irrecoverable. // irrecoverable.
return err return err
} }
var poll gtsmodel.Poll if len(choicesSl) != 1 {
// No poll votes by this
// acct on this poll.
return nil
}
choices := choicesSl[0]
// Select poll counts from DB. // Select current poll counts from DB,
// taking minimal columns needed to
// increment/decrement votes.
var poll gtsmodel.Poll
switch err := tx.NewSelect(). switch err := tx.NewSelect().
Model(&poll). Model(&poll).
Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), pollID). Where("? = ?", bun.Ident("id"), pollID).
Scan(ctx); { Scan(ctx); {
@ -468,7 +477,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
// no issue. // no issue.
case errors.Is(err, db.ErrNoEntries): case errors.Is(err, db.ErrNoEntries):
// no votes found, // no poll found,
// return here. // return here.
return nil return nil

View File

@ -26,6 +26,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/util"
@ -304,15 +305,64 @@ func (suite *PollTestSuite) TestDeletePollVotes() {
suite.NoError(err) suite.NoError(err)
// Fetch latest version of poll from database. // Fetch latest version of poll from database.
poll, err = suite.db.GetPollByID(ctx, poll.ID) poll, err = suite.db.GetPollByID(
gtscontext.SetBarebones(ctx),
poll.ID,
)
suite.NoError(err) suite.NoError(err)
// Check that poll counts are all zero. // Check that poll counts are all zero.
suite.Equal(*poll.Voters, 0) suite.Equal(*poll.Voters, 0)
suite.Equal(poll.Votes, make([]int, len(poll.Options))) suite.Equal(make([]int, len(poll.Options)), poll.Votes)
} }
} }
func (suite *PollTestSuite) TestDeletePollVotesNoPoll() {
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Try to delete votes of nonexistent poll.
nonPollID := "01HF6V4XWTSZWJ80JNPPDTD4DB"
err := suite.db.DeletePollVotes(ctx, nonPollID)
suite.NoError(err)
}
func (suite *PollTestSuite) TestDeletePollVotesBy() {
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
for _, vote := range suite.testPollVotes {
// Fetch before version of pollBefore from database.
pollBefore, err := suite.db.GetPollByID(ctx, vote.PollID)
suite.NoError(err)
// Delete this poll vote.
err = suite.db.DeletePollVoteBy(ctx, vote.PollID, vote.AccountID)
suite.NoError(err)
// Fetch after version of poll from database.
pollAfter, err := suite.db.GetPollByID(ctx, vote.PollID)
suite.NoError(err)
// Voters count should be reduced by 1.
suite.Equal(*pollBefore.Voters-1, *pollAfter.Voters)
}
}
func (suite *PollTestSuite) TestDeletePollVotesByNoAccount() {
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Try to delete a poll by nonexisting account.
pollID := suite.testPolls["local_account_1_status_6_poll"].ID
nonAccountID := "01HF6T545G1G8ZNMY1S3ZXJ608"
err := suite.db.DeletePollVoteBy(ctx, pollID, nonAccountID)
suite.NoError(err)
}
func TestPollTestSuite(t *testing.T) { func TestPollTestSuite(t *testing.T) {
suite.Run(t, new(PollTestSuite)) suite.Run(t, new(PollTestSuite))
} }