[bugfix] Update poll delete/update db queries (#2361)
This commit is contained in:
parent
8d0c017cf2
commit
0b99f14d64
|
@ -44,7 +44,7 @@ func init() {
|
||||||
Table("polls").
|
Table("polls").
|
||||||
Column("expires_at_new").
|
Column("expires_at_new").
|
||||||
Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")).
|
Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")).
|
||||||
Where("1"). // bun gets angry performing update over all rows
|
Where("TRUE"). // bun gets angry performing update over all rows
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -341,9 +341,12 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
|
||||||
|
|
||||||
var poll gtsmodel.Poll
|
var poll gtsmodel.Poll
|
||||||
|
|
||||||
// Select poll counts from DB.
|
// Select current poll counts from DB,
|
||||||
|
// taking minimal columns needed to
|
||||||
|
// increment/decrement votes.
|
||||||
if err := tx.NewSelect().
|
if err := tx.NewSelect().
|
||||||
Model(&poll).
|
Model(&poll).
|
||||||
|
Column("options", "votes", "voters").
|
||||||
Where("? = ?", bun.Ident("id"), vote.PollID).
|
Where("? = ?", bun.Ident("id"), vote.PollID).
|
||||||
Scan(ctx); err != nil {
|
Scan(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -365,31 +368,35 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
|
||||||
|
|
||||||
func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
|
func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
|
||||||
err := p.db.RunInTx(ctx, func(tx Tx) error {
|
err := p.db.RunInTx(ctx, func(tx Tx) error {
|
||||||
// Delete all vote in poll,
|
// Delete all votes in poll.
|
||||||
// returning all vote choices.
|
res, err := tx.NewDelete().
|
||||||
switch _, err := tx.NewDelete().
|
|
||||||
Table("poll_votes").
|
Table("poll_votes").
|
||||||
Where("? = ?", bun.Ident("poll_id"), pollID).
|
Where("? = ?", bun.Ident("poll_id"), pollID).
|
||||||
Exec(ctx); {
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
case err == nil:
|
// irrecoverable
|
||||||
// no issue.
|
|
||||||
|
|
||||||
case errors.Is(err, db.ErrNoEntries):
|
|
||||||
// no votes found,
|
|
||||||
// return here.
|
|
||||||
return nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// irrecoverable.
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var poll gtsmodel.Poll
|
ra, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
// irrecoverable
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Select poll counts from DB.
|
if ra == 0 {
|
||||||
|
// No poll votes deleted,
|
||||||
|
// nothing to update.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select current poll counts from DB,
|
||||||
|
// taking minimal columns needed to
|
||||||
|
// increment/decrement votes.
|
||||||
|
var poll gtsmodel.Poll
|
||||||
switch err := tx.NewSelect().
|
switch err := tx.NewSelect().
|
||||||
Model(&poll).
|
Model(&poll).
|
||||||
|
Column("options", "votes", "voters").
|
||||||
Where("? = ?", bun.Ident("id"), pollID).
|
Where("? = ?", bun.Ident("id"), pollID).
|
||||||
Scan(ctx); {
|
Scan(ctx); {
|
||||||
|
|
||||||
|
@ -410,7 +417,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
|
||||||
poll.ResetVotes()
|
poll.ResetVotes()
|
||||||
|
|
||||||
// Finally, update the poll entry.
|
// Finally, update the poll entry.
|
||||||
_, err := tx.NewUpdate().
|
_, err = tx.NewUpdate().
|
||||||
Model(&poll).
|
Model(&poll).
|
||||||
Column("votes", "voters").
|
Column("votes", "voters").
|
||||||
Where("? = ?", bun.Ident("id"), pollID).
|
Where("? = ?", bun.Ident("id"), pollID).
|
||||||
|
@ -432,35 +439,37 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
|
||||||
|
|
||||||
func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
|
func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
|
||||||
err := p.db.RunInTx(ctx, func(tx Tx) error {
|
err := p.db.RunInTx(ctx, func(tx Tx) error {
|
||||||
var choices []int
|
// Slice should only ever be of length
|
||||||
|
// 0 or 1; it's a slice of slices only
|
||||||
|
// because we can't LIMIT deletes to 1.
|
||||||
|
var choicesSl [][]int
|
||||||
|
|
||||||
// Delete vote in poll by account,
|
// Delete vote in poll by account,
|
||||||
// returning the ID + choices of the vote.
|
// returning the ID + choices of the vote.
|
||||||
switch err := tx.NewDelete().
|
if err := tx.NewDelete().
|
||||||
Table("poll_votes").
|
Table("poll_votes").
|
||||||
Where("? = ?", bun.Ident("poll_id"), pollID).
|
Where("? = ?", bun.Ident("poll_id"), pollID).
|
||||||
Where("? = ?", bun.Ident("account_id"), accountID).
|
Where("? = ?", bun.Ident("account_id"), accountID).
|
||||||
Returning("choices").
|
Returning("?", bun.Ident("choices")).
|
||||||
Scan(ctx, &choices); {
|
Scan(ctx, &choicesSl); err != nil {
|
||||||
|
|
||||||
case err == nil:
|
|
||||||
// no issue.
|
|
||||||
|
|
||||||
case errors.Is(err, db.ErrNoEntries):
|
|
||||||
// no votes found,
|
|
||||||
// return here.
|
|
||||||
return nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// irrecoverable.
|
// irrecoverable.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var poll gtsmodel.Poll
|
if len(choicesSl) != 1 {
|
||||||
|
// No poll votes by this
|
||||||
|
// acct on this poll.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
choices := choicesSl[0]
|
||||||
|
|
||||||
// Select poll counts from DB.
|
// Select current poll counts from DB,
|
||||||
|
// taking minimal columns needed to
|
||||||
|
// increment/decrement votes.
|
||||||
|
var poll gtsmodel.Poll
|
||||||
switch err := tx.NewSelect().
|
switch err := tx.NewSelect().
|
||||||
Model(&poll).
|
Model(&poll).
|
||||||
|
Column("options", "votes", "voters").
|
||||||
Where("? = ?", bun.Ident("id"), pollID).
|
Where("? = ?", bun.Ident("id"), pollID).
|
||||||
Scan(ctx); {
|
Scan(ctx); {
|
||||||
|
|
||||||
|
@ -468,7 +477,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
|
||||||
// no issue.
|
// no issue.
|
||||||
|
|
||||||
case errors.Is(err, db.ErrNoEntries):
|
case errors.Is(err, db.ErrNoEntries):
|
||||||
// no votes found,
|
// no poll found,
|
||||||
// return here.
|
// return here.
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ import (
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
|
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||||
|
@ -304,15 +305,64 @@ func (suite *PollTestSuite) TestDeletePollVotes() {
|
||||||
suite.NoError(err)
|
suite.NoError(err)
|
||||||
|
|
||||||
// Fetch latest version of poll from database.
|
// Fetch latest version of poll from database.
|
||||||
poll, err = suite.db.GetPollByID(ctx, poll.ID)
|
poll, err = suite.db.GetPollByID(
|
||||||
|
gtscontext.SetBarebones(ctx),
|
||||||
|
poll.ID,
|
||||||
|
)
|
||||||
suite.NoError(err)
|
suite.NoError(err)
|
||||||
|
|
||||||
// Check that poll counts are all zero.
|
// Check that poll counts are all zero.
|
||||||
suite.Equal(*poll.Voters, 0)
|
suite.Equal(*poll.Voters, 0)
|
||||||
suite.Equal(poll.Votes, make([]int, len(poll.Options)))
|
suite.Equal(make([]int, len(poll.Options)), poll.Votes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (suite *PollTestSuite) TestDeletePollVotesNoPoll() {
|
||||||
|
// Create a new context for this test.
|
||||||
|
ctx, cncl := context.WithCancel(context.Background())
|
||||||
|
defer cncl()
|
||||||
|
|
||||||
|
// Try to delete votes of nonexistent poll.
|
||||||
|
nonPollID := "01HF6V4XWTSZWJ80JNPPDTD4DB"
|
||||||
|
|
||||||
|
err := suite.db.DeletePollVotes(ctx, nonPollID)
|
||||||
|
suite.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *PollTestSuite) TestDeletePollVotesBy() {
|
||||||
|
ctx, cncl := context.WithCancel(context.Background())
|
||||||
|
defer cncl()
|
||||||
|
|
||||||
|
for _, vote := range suite.testPollVotes {
|
||||||
|
// Fetch before version of pollBefore from database.
|
||||||
|
pollBefore, err := suite.db.GetPollByID(ctx, vote.PollID)
|
||||||
|
suite.NoError(err)
|
||||||
|
|
||||||
|
// Delete this poll vote.
|
||||||
|
err = suite.db.DeletePollVoteBy(ctx, vote.PollID, vote.AccountID)
|
||||||
|
suite.NoError(err)
|
||||||
|
|
||||||
|
// Fetch after version of poll from database.
|
||||||
|
pollAfter, err := suite.db.GetPollByID(ctx, vote.PollID)
|
||||||
|
suite.NoError(err)
|
||||||
|
|
||||||
|
// Voters count should be reduced by 1.
|
||||||
|
suite.Equal(*pollBefore.Voters-1, *pollAfter.Voters)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *PollTestSuite) TestDeletePollVotesByNoAccount() {
|
||||||
|
ctx, cncl := context.WithCancel(context.Background())
|
||||||
|
defer cncl()
|
||||||
|
|
||||||
|
// Try to delete a poll by nonexisting account.
|
||||||
|
pollID := suite.testPolls["local_account_1_status_6_poll"].ID
|
||||||
|
nonAccountID := "01HF6T545G1G8ZNMY1S3ZXJ608"
|
||||||
|
|
||||||
|
err := suite.db.DeletePollVoteBy(ctx, pollID, nonAccountID)
|
||||||
|
suite.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPollTestSuite(t *testing.T) {
|
func TestPollTestSuite(t *testing.T) {
|
||||||
suite.Run(t, new(PollTestSuite))
|
suite.Run(t, new(PollTestSuite))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue