[performance] simpler throttling logic (#2407)
* reduce complexity of throttling logic to use 1 queue and an atomic int * use atomic add instead of CAS, add throttling test
This commit is contained in:
parent
1312695c46
commit
d56a8d095e
|
@ -42,6 +42,12 @@ var (
|
||||||
StatusInternalServerErrorJSON = mustJSON(map[string]string{
|
StatusInternalServerErrorJSON = mustJSON(map[string]string{
|
||||||
"status": http.StatusText(http.StatusInternalServerError),
|
"status": http.StatusText(http.StatusInternalServerError),
|
||||||
})
|
})
|
||||||
|
ErrorCapacityExceeded = mustJSON(map[string]string{
|
||||||
|
"error": "server capacity exceeded!",
|
||||||
|
})
|
||||||
|
ErrorRateLimitReached = mustJSON(map[string]string{
|
||||||
|
"error": "rate limit reached!",
|
||||||
|
})
|
||||||
EmptyJSONObject = mustJSON("{}")
|
EmptyJSONObject = mustJSON("{}")
|
||||||
EmptyJSONArray = mustJSON("[]")
|
EmptyJSONArray = mustJSON("[]")
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,8 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||||
"github.com/ulule/limiter/v3"
|
"github.com/ulule/limiter/v3"
|
||||||
"github.com/ulule/limiter/v3/drivers/store/memory"
|
"github.com/ulule/limiter/v3/drivers/store/memory"
|
||||||
|
|
||||||
|
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const rateLimitPeriod = 5 * time.Minute
|
const rateLimitPeriod = 5 * time.Minute
|
||||||
|
@ -141,10 +143,12 @@ func RateLimit(limit int, exceptions []string) gin.HandlerFunc {
|
||||||
if context.Reached {
|
if context.Reached {
|
||||||
// Return JSON error message for
|
// Return JSON error message for
|
||||||
// consistency with other endpoints.
|
// consistency with other endpoints.
|
||||||
c.AbortWithStatusJSON(
|
apiutil.Data(c,
|
||||||
http.StatusTooManyRequests,
|
http.StatusTooManyRequests,
|
||||||
gin.H{"error": "rate limit reached"},
|
apiutil.AppJSON,
|
||||||
|
apiutil.ErrorRateLimitReached,
|
||||||
)
|
)
|
||||||
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,12 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// token represents a request that is being processed.
|
// token represents a request that is being processed.
|
||||||
|
@ -80,55 +83,61 @@ func Throttle(cpuMultiplier int, retryAfter time.Duration) gin.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
limit = runtime.GOMAXPROCS(0) * cpuMultiplier
|
limit = runtime.GOMAXPROCS(0) * cpuMultiplier
|
||||||
backlogLimit = limit * cpuMultiplier
|
queueLimit = limit * cpuMultiplier
|
||||||
backlogChannelSize = limit + backlogLimit
|
tokens = make(chan token, limit)
|
||||||
tokens = make(chan token, limit)
|
requestCount = atomic.Int64{}
|
||||||
backlogTokens = make(chan token, backlogChannelSize)
|
retryAfterStr = strconv.FormatUint(uint64(retryAfter/time.Second), 10)
|
||||||
retryAfterStr = strconv.FormatUint(uint64(retryAfter/time.Second), 10)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// prefill token channels
|
// prefill token channel
|
||||||
for i := 0; i < limit; i++ {
|
for i := 0; i < limit; i++ {
|
||||||
tokens <- token{}
|
tokens <- token{}
|
||||||
}
|
}
|
||||||
for i := 0; i < backlogChannelSize; i++ {
|
|
||||||
backlogTokens <- token{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// inside this select, the caller tries to get a backlog token
|
// Always decrement request counter.
|
||||||
select {
|
defer func() { requestCount.Add(-1) }()
|
||||||
case <-c.Request.Context().Done():
|
|
||||||
// request context has been canceled already
|
// Increment request count.
|
||||||
|
n := requestCount.Add(1)
|
||||||
|
|
||||||
|
// Check whether the request
|
||||||
|
// count is over queue limit.
|
||||||
|
if n > int64(queueLimit) {
|
||||||
|
c.Header("Retry-After", retryAfterStr)
|
||||||
|
apiutil.Data(c,
|
||||||
|
http.StatusTooManyRequests,
|
||||||
|
apiutil.AppJSON,
|
||||||
|
apiutil.ErrorCapacityExceeded,
|
||||||
|
)
|
||||||
|
c.Abort()
|
||||||
return
|
return
|
||||||
case btok := <-backlogTokens:
|
}
|
||||||
|
|
||||||
|
// Sit and wait in the
|
||||||
|
// queue for free token.
|
||||||
|
select {
|
||||||
|
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
// request context has
|
||||||
|
// been canceled already.
|
||||||
|
return
|
||||||
|
|
||||||
|
case tok := <-tokens:
|
||||||
|
// caller has successfully
|
||||||
|
// received a token, allowing
|
||||||
|
// request to be processed.
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// when we're finished, return the backlog token to the bucket
|
// when we're finished, return
|
||||||
backlogTokens <- btok
|
// this token to the bucket.
|
||||||
|
tokens <- tok
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// inside *this* select, the caller has a backlog token,
|
// Process
|
||||||
// and they're waiting for their turn to be processed
|
// request!
|
||||||
select {
|
c.Next()
|
||||||
case <-c.Request.Context().Done():
|
|
||||||
// the request context has been canceled already
|
|
||||||
return
|
|
||||||
case tok := <-tokens:
|
|
||||||
// the caller gets a token, so their request can now be processed
|
|
||||||
defer func() {
|
|
||||||
// whatever happens to the request, put the
|
|
||||||
// token back in the bucket when we're finished
|
|
||||||
tokens <- tok
|
|
||||||
}()
|
|
||||||
c.Next() // <- finally process the caller's request
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
// we don't have space in the backlog queue
|
|
||||||
c.Header("Retry-After", retryAfterStr)
|
|
||||||
c.JSON(http.StatusTooManyRequests, gin.H{"error": "server capacity exceeded"})
|
|
||||||
c.Abort()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
// GoToSocial
|
||||||
|
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||||
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
/*
|
||||||
|
The code in this file is adapted from MIT-licensed code in github.com/go-chi/chi. Thanks chi (thi)!
|
||||||
|
|
||||||
|
See: https://github.com/go-chi/chi/blob/e6baba61759b26ddf7b14d1e02d1da81a4d76c08/middleware/throttle.go
|
||||||
|
|
||||||
|
And: https://github.com/sponsors/pkieltyka
|
||||||
|
*/
|
||||||
|
|
||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/superseriousbusiness/gotosocial/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestThrottlingMiddleware(t *testing.T) {
|
||||||
|
testThrottlingMiddleware(t, 2, time.Second*10)
|
||||||
|
testThrottlingMiddleware(t, 4, time.Second*15)
|
||||||
|
testThrottlingMiddleware(t, 8, time.Second*30)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testThrottlingMiddleware(t *testing.T, cpuMulti int, retryAfter time.Duration) {
|
||||||
|
// Calculate expected request limit + queue.
|
||||||
|
limit := runtime.GOMAXPROCS(0) * cpuMulti
|
||||||
|
queueLimit := limit * cpuMulti
|
||||||
|
|
||||||
|
// Calculate expected retry-after header string.
|
||||||
|
retryAfterStr := strconv.FormatUint(uint64(retryAfter/time.Second), 10)
|
||||||
|
|
||||||
|
// Gin test http engine
|
||||||
|
// (used for ctx init).
|
||||||
|
e := gin.New()
|
||||||
|
|
||||||
|
// Add middleware to the gin engine handler stack.
|
||||||
|
middleware := middleware.Throttle(cpuMulti, retryAfter)
|
||||||
|
e.Use(middleware)
|
||||||
|
|
||||||
|
// Set the blocking gin handler.
|
||||||
|
handler := blockingHandler()
|
||||||
|
e.Handle("GET", "/", handler)
|
||||||
|
|
||||||
|
var cncls []func()
|
||||||
|
|
||||||
|
for i := 0; i < queueLimit+limit; i++ {
|
||||||
|
// Prepare a gin test context.
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Wrap request with new cancel context.
|
||||||
|
ctx, cncl := context.WithCancel(r.Context())
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
// Pass req through
|
||||||
|
// engine handler.
|
||||||
|
go e.ServeHTTP(rw, r)
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
// Get http result.
|
||||||
|
res := rw.Result()
|
||||||
|
|
||||||
|
if i < queueLimit {
|
||||||
|
|
||||||
|
// Check status == 200 (default, i.e not set).
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status code was set (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add cancel to func slice.
|
||||||
|
cncls = append(cncls, cncl)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
// Check the returned status code is expected.
|
||||||
|
if res.StatusCode != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("did not return status 429 (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the returned retry-after header is set.
|
||||||
|
if res.Header.Get("Retry-After") != retryAfterStr {
|
||||||
|
t.Fatalf("did not return retry-after %s with queueLimit=%d and request=%d", retryAfterStr, queueLimit, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel on return.
|
||||||
|
defer cncl()
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel all blocked reqs.
|
||||||
|
for _, cncl := range cncls {
|
||||||
|
cncl()
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
// Check a bunchh more requests
|
||||||
|
// can now make it through after
|
||||||
|
// previous requests were released!
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
|
||||||
|
// Prepare a gin test context.
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Pass req through
|
||||||
|
// engine handler.
|
||||||
|
go e.ServeHTTP(rw, r)
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
// Get http result.
|
||||||
|
res := rw.Result()
|
||||||
|
|
||||||
|
// Check status == 200 (default, i.e not set).
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status code was set (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func blockingHandler() gin.HandlerFunc {
|
||||||
|
return func(ctx *gin.Context) {
|
||||||
|
<-ctx.Done()
|
||||||
|
ctx.Status(201) // specifically not 200
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue